feat: Support create collection with functions (#35973)

Related issue: https://github.com/milvus-io/milvus/issues/35853
Support creating collections with functions, in preparation for
supporting the BM25 function.

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2024-09-12 10:43:06 +08:00 committed by GitHub
parent 08e681174a
commit da227ff9a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 765 additions and 140 deletions

2
go.mod
View File

@ -22,7 +22,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.7
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd
github.com/minio/minio-go/v7 v7.0.61
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0

4
go.sum
View File

@ -602,8 +602,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb h1:S3QIkNv9N1Vd1UKtdaQ4yVDPFAwFiPSAjN07axzbR70=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd h1:x0b0+foTe23sKcVFseR1DE8+BB08EH6ViiRHaz8PEik=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A=
github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=

View File

@ -60,6 +60,14 @@ func BuildFieldKey(collectionID typeutil.UniqueID, fieldID int64) string {
return fmt.Sprintf("%s/%d", BuildFieldPrefix(collectionID), fieldID)
}
// BuildFunctionPrefix returns the meta-kv key prefix under which all
// function schemas of the given collection are stored.
func BuildFunctionPrefix(collectionID typeutil.UniqueID) string {
	return fmt.Sprintf("%s/%d", FunctionMetaPrefix, collectionID)
}
// BuildFunctionKey returns the meta-kv key of a single function schema,
// i.e. BuildFunctionPrefix(collectionID) + "/" + functionID.
func BuildFunctionKey(collectionID typeutil.UniqueID, functionID int64) string {
	return fmt.Sprintf("%s/%d", BuildFunctionPrefix(collectionID), functionID)
}
func BuildAliasKey210(alias string) string {
return fmt.Sprintf("%s/%s", CollectionAliasMetaPrefix210, alias)
}
@ -166,7 +174,7 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection,
kvs := map[string]string{}
// save partition info to newly path.
// save partition info to new path.
for _, partition := range coll.Partitions {
k := BuildPartitionKey(coll.CollectionID, partition.PartitionID)
partitionInfo := model.MarshalPartitionModel(partition)
@ -178,8 +186,7 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection,
}
// no default aliases will be created.
// save fields info to newly path.
// save fields info to new path.
for _, field := range coll.Fields {
k := BuildFieldKey(coll.CollectionID, field.FieldID)
fieldInfo := model.MarshalFieldModel(field)
@ -190,6 +197,17 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection,
kvs[k] = string(v)
}
// save functions info to new path.
for _, function := range coll.Functions {
k := BuildFunctionKey(coll.CollectionID, function.ID)
functionInfo := model.MarshalFunctionModel(function)
v, err := proto.Marshal(functionInfo)
if err != nil {
return err
}
kvs[k] = string(v)
}
// Though batchSave is not atomic enough, we can promise the atomicity outside.
// Recovering from failure, if we found collection is creating, we should remove all these related meta.
// since SnapshotKV may save both snapshot key and the original key if the original key is newest
@ -358,6 +376,24 @@ func (kc *Catalog) listFieldsAfter210(ctx context.Context, collectionID typeutil
return fields, nil
}
// listFunctions loads every function schema stored under the collection's
// function prefix at snapshot timestamp ts and unmarshals each value into
// the in-memory model.
func (kc *Catalog) listFunctions(collectionID typeutil.UniqueID, ts typeutil.Timestamp) ([]*model.Function, error) {
	_, values, err := kc.Snapshot.LoadWithPrefix(BuildFunctionPrefix(collectionID), ts)
	if err != nil {
		return nil, err
	}

	result := make([]*model.Function, 0, len(values))
	for _, raw := range values {
		schema := &schemapb.FunctionSchema{}
		if err := proto.Unmarshal([]byte(raw), schema); err != nil {
			return nil, err
		}
		result = append(result, model.UnmarshalFunctionModel(schema))
	}
	return result, nil
}
func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *pb.CollectionInfo,
ts typeutil.Timestamp,
) (*model.Collection, error) {
@ -379,6 +415,11 @@ func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *p
}
collection.Fields = fields
functions, err := kc.listFunctions(collection.CollectionID, ts)
if err != nil {
return nil, err
}
collection.Functions = functions
return collection, nil
}
@ -441,6 +482,9 @@ func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Col
for _, field := range collectionInfo.Fields {
delMetakeysSnap = append(delMetakeysSnap, BuildFieldKey(collectionInfo.CollectionID, field.FieldID))
}
for _, function := range collectionInfo.Functions {
delMetakeysSnap = append(delMetakeysSnap, BuildFunctionKey(collectionInfo.CollectionID, function.ID))
}
// delMetakeysSnap = append(delMetakeysSnap, buildPartitionPrefix(collectionInfo.CollectionID))
// delMetakeysSnap = append(delMetakeysSnap, buildFieldPrefix(collectionInfo.CollectionID))

View File

@ -207,8 +207,17 @@ func TestCatalog_ListCollections(t *testing.T) {
return strings.HasPrefix(prefix, FieldMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(fm)}, nil)
kc := Catalog{Snapshot: kv}
functionMeta := &schemapb.FunctionSchema{}
fcm, err := proto.Marshal(functionMeta)
assert.NoError(t, err)
kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool {
return strings.HasPrefix(prefix, FunctionMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(fcm)}, nil)
kc := Catalog{Snapshot: kv}
ret, err := kc.ListCollections(ctx, testDb, ts)
assert.NoError(t, err)
assert.NotNil(t, ret)
@ -248,6 +257,16 @@ func TestCatalog_ListCollections(t *testing.T) {
return strings.HasPrefix(prefix, FieldMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(fm)}, nil)
functionMeta := &schemapb.FunctionSchema{}
fcm, err := proto.Marshal(functionMeta)
assert.NoError(t, err)
kv.On("LoadWithPrefix", mock.MatchedBy(
func(prefix string) bool {
return strings.HasPrefix(prefix, FunctionMetaPrefix)
}), ts).
Return([]string{"key"}, []string{string(fcm)}, nil)
kv.On("MultiSaveAndRemove", mock.Anything, mock.Anything, ts).Return(nil)
kc := Catalog{Snapshot: kv}
@ -1215,6 +1234,22 @@ func TestCatalog_CreateCollection(t *testing.T) {
err := kc.CreateCollection(ctx, coll, 100)
assert.NoError(t, err)
})
t.Run("create collection with function", func(t *testing.T) {
mockSnapshot := newMockSnapshot(t, withMockSave(nil), withMockMultiSave(nil))
kc := &Catalog{Snapshot: mockSnapshot}
ctx := context.Background()
coll := &model.Collection{
Partitions: []*model.Partition{
{PartitionName: "test"},
},
Fields: []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}},
Functions: []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}},
State: pb.CollectionState_CollectionCreating,
}
err := kc.CreateCollection(ctx, coll, 100)
assert.NoError(t, err)
})
}
func TestCatalog_DropCollection(t *testing.T) {
@ -1281,6 +1316,22 @@ func TestCatalog_DropCollection(t *testing.T) {
err := kc.DropCollection(ctx, coll, 100)
assert.NoError(t, err)
})
t.Run("drop collection with function", func(t *testing.T) {
mockSnapshot := newMockSnapshot(t, withMockMultiSaveAndRemove(nil))
kc := &Catalog{Snapshot: mockSnapshot}
ctx := context.Background()
coll := &model.Collection{
Partitions: []*model.Partition{
{PartitionName: "test"},
},
Fields: []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}},
Functions: []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}},
State: pb.CollectionState_CollectionDropping,
}
err := kc.DropCollection(ctx, coll, 100)
assert.NoError(t, err)
})
}
func getUserInfoMetaString(username string) string {
@ -2779,3 +2830,15 @@ func TestCatalog_AlterDatabase(t *testing.T) {
err = c.AlterDatabase(ctx, newDB, typeutil.ZeroTimestamp)
assert.ErrorIs(t, err, mockErr)
}
// TestCatalog_listFunctionError covers the two failure paths of
// Catalog.listFunctions: the snapshot load itself failing, and a stored
// value that cannot be unmarshaled into a schemapb.FunctionSchema.
func TestCatalog_listFunctionError(t *testing.T) {
	mockSnapshot := newMockSnapshot(t)
	kc := &Catalog{Snapshot: mockSnapshot}
	// load error must be propagated
	mockSnapshot.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("mock error"))
	_, err := kc.listFunctions(1, 1)
	assert.Error(t, err)
	// invalid stored bytes must fail proto.Unmarshal
	// NOTE(review): the first expectation above has no call-count limit, so
	// it may keep matching this second call as well; both paths return an
	// error either way — confirm the intended expectation ordering.
	mockSnapshot.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return([]string{"test-key"}, []string{"invalid bytes"}, nil)
	_, err = kc.listFunctions(1, 1)
	assert.Error(t, err)
}

View File

@ -20,6 +20,7 @@ const (
PartitionMetaPrefix = ComponentPrefix + "/partitions"
AliasMetaPrefix = ComponentPrefix + "/aliases"
FieldMetaPrefix = ComponentPrefix + "/fields"
FunctionMetaPrefix = ComponentPrefix + "/functions"
// CollectionAliasMetaPrefix210 prefix for collection alias meta
CollectionAliasMetaPrefix210 = ComponentPrefix + "/collection-alias"

View File

@ -18,6 +18,7 @@ type Collection struct {
Description string
AutoID bool
Fields []*Field
Functions []*Function
VirtualChannelNames []string
PhysicalChannelNames []string
ShardsNum int32
@ -54,6 +55,7 @@ func (c *Collection) Clone() *Collection {
Properties: common.CloneKeyValuePairs(c.Properties),
State: c.State,
EnableDynamicField: c.EnableDynamicField,
Functions: CloneFunctions(c.Functions),
}
}

View File

@ -12,14 +12,16 @@ import (
)
var (
colID int64 = 1
colName = "c"
fieldID int64 = 101
fieldName = "field110"
partID int64 = 20
partName = "testPart"
tenantID = "tenant-1"
typeParams = []*commonpb.KeyValuePair{
colID int64 = 1
colName = "c"
fieldID int64 = 101
fieldName = "field110"
partID int64 = 20
partName = "testPart"
tenantID = "tenant-1"
functionID int64 = 1
functionName = "test-bm25"
typeParams = []*commonpb.KeyValuePair{
{
Key: "field110-k1",
Value: "field110-v1",

View File

@ -7,21 +7,22 @@ import (
)
type Field struct {
FieldID int64
Name string
IsPrimaryKey bool
Description string
DataType schemapb.DataType
TypeParams []*commonpb.KeyValuePair
IndexParams []*commonpb.KeyValuePair
AutoID bool
State schemapb.FieldState
IsDynamic bool
IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition
IsClusteringKey bool
DefaultValue *schemapb.ValueField
ElementType schemapb.DataType
Nullable bool
FieldID int64
Name string
IsPrimaryKey bool
Description string
DataType schemapb.DataType
TypeParams []*commonpb.KeyValuePair
IndexParams []*commonpb.KeyValuePair
AutoID bool
State schemapb.FieldState
IsDynamic bool
IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition
IsClusteringKey bool
IsFunctionOutput bool
DefaultValue *schemapb.ValueField
ElementType schemapb.DataType
Nullable bool
}
func (f *Field) Available() bool {
@ -30,21 +31,22 @@ func (f *Field) Available() bool {
func (f *Field) Clone() *Field {
return &Field{
FieldID: f.FieldID,
Name: f.Name,
IsPrimaryKey: f.IsPrimaryKey,
Description: f.Description,
DataType: f.DataType,
TypeParams: common.CloneKeyValuePairs(f.TypeParams),
IndexParams: common.CloneKeyValuePairs(f.IndexParams),
AutoID: f.AutoID,
State: f.State,
IsDynamic: f.IsDynamic,
IsPartitionKey: f.IsPartitionKey,
IsClusteringKey: f.IsClusteringKey,
DefaultValue: f.DefaultValue,
ElementType: f.ElementType,
Nullable: f.Nullable,
FieldID: f.FieldID,
Name: f.Name,
IsPrimaryKey: f.IsPrimaryKey,
Description: f.Description,
DataType: f.DataType,
TypeParams: common.CloneKeyValuePairs(f.TypeParams),
IndexParams: common.CloneKeyValuePairs(f.IndexParams),
AutoID: f.AutoID,
State: f.State,
IsDynamic: f.IsDynamic,
IsPartitionKey: f.IsPartitionKey,
IsClusteringKey: f.IsClusteringKey,
IsFunctionOutput: f.IsFunctionOutput,
DefaultValue: f.DefaultValue,
ElementType: f.ElementType,
Nullable: f.Nullable,
}
}
@ -75,6 +77,7 @@ func (f *Field) Equal(other Field) bool {
f.IsClusteringKey == other.IsClusteringKey &&
f.DefaultValue == other.DefaultValue &&
f.ElementType == other.ElementType &&
f.IsFunctionOutput == other.IsFunctionOutput &&
f.Nullable == other.Nullable
}
@ -97,20 +100,21 @@ func MarshalFieldModel(field *Field) *schemapb.FieldSchema {
}
return &schemapb.FieldSchema{
FieldID: field.FieldID,
Name: field.Name,
IsPrimaryKey: field.IsPrimaryKey,
Description: field.Description,
DataType: field.DataType,
TypeParams: field.TypeParams,
IndexParams: field.IndexParams,
AutoID: field.AutoID,
IsDynamic: field.IsDynamic,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
DefaultValue: field.DefaultValue,
ElementType: field.ElementType,
Nullable: field.Nullable,
FieldID: field.FieldID,
Name: field.Name,
IsPrimaryKey: field.IsPrimaryKey,
Description: field.Description,
DataType: field.DataType,
TypeParams: field.TypeParams,
IndexParams: field.IndexParams,
AutoID: field.AutoID,
IsDynamic: field.IsDynamic,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
IsFunctionOutput: field.IsFunctionOutput,
DefaultValue: field.DefaultValue,
ElementType: field.ElementType,
Nullable: field.Nullable,
}
}
@ -132,20 +136,21 @@ func UnmarshalFieldModel(fieldSchema *schemapb.FieldSchema) *Field {
}
return &Field{
FieldID: fieldSchema.FieldID,
Name: fieldSchema.Name,
IsPrimaryKey: fieldSchema.IsPrimaryKey,
Description: fieldSchema.Description,
DataType: fieldSchema.DataType,
TypeParams: fieldSchema.TypeParams,
IndexParams: fieldSchema.IndexParams,
AutoID: fieldSchema.AutoID,
IsDynamic: fieldSchema.IsDynamic,
IsPartitionKey: fieldSchema.IsPartitionKey,
IsClusteringKey: fieldSchema.IsClusteringKey,
DefaultValue: fieldSchema.DefaultValue,
ElementType: fieldSchema.ElementType,
Nullable: fieldSchema.Nullable,
FieldID: fieldSchema.FieldID,
Name: fieldSchema.Name,
IsPrimaryKey: fieldSchema.IsPrimaryKey,
Description: fieldSchema.Description,
DataType: fieldSchema.DataType,
TypeParams: fieldSchema.TypeParams,
IndexParams: fieldSchema.IndexParams,
AutoID: fieldSchema.AutoID,
IsDynamic: fieldSchema.IsDynamic,
IsPartitionKey: fieldSchema.IsPartitionKey,
IsClusteringKey: fieldSchema.IsClusteringKey,
IsFunctionOutput: fieldSchema.IsFunctionOutput,
DefaultValue: fieldSchema.DefaultValue,
ElementType: fieldSchema.ElementType,
Nullable: fieldSchema.Nullable,
}
}

View File

@ -0,0 +1,120 @@
package model
import (
"slices"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// Function is the in-memory model of a collection-level function (e.g. a
// BM25 function mapping input fields to output fields). It mirrors
// schemapb.FunctionSchema; see MarshalFunctionModel / UnmarshalFunctionModel
// for the conversions.
type Function struct {
	Name        string
	ID          int64
	Description string
	Type        schemapb.FunctionType

	// Input/output fields are tracked both by ID and by name; presumably the
	// IDs are filled once field IDs are assigned to the schema — TODO confirm.
	InputFieldIDs    []int64
	InputFieldNames  []string
	OutputFieldIDs   []int64
	OutputFieldNames []string

	// Function-specific parameters (e.g. bm25_k1 / bm25_b / bm25_avgdl).
	Params []*commonpb.KeyValuePair
}
// Clone returns a copy of the function model. Slice fields are copied with
// slices.Clone so appends or element overwrites on the clone do not alias
// the receiver's backing arrays (previously all slices were shared).
// Elements of Params are *commonpb.KeyValuePair pointers and remain shared;
// the pairs themselves are not duplicated.
func (f *Function) Clone() *Function {
	return &Function{
		Name:             f.Name,
		ID:               f.ID,
		Description:      f.Description,
		Type:             f.Type,
		InputFieldIDs:    slices.Clone(f.InputFieldIDs),
		InputFieldNames:  slices.Clone(f.InputFieldNames),
		OutputFieldIDs:   slices.Clone(f.OutputFieldIDs),
		OutputFieldNames: slices.Clone(f.OutputFieldNames),
		Params:           slices.Clone(f.Params),
	}
}
// Equal reports whether two functions describe the same logical function.
// ID is not compared (assumed intentional: Equal compares the user-visible
// definition, not system-assigned identity — TODO confirm).
//
// Fix: Params was compared with slices.Equal over []*commonpb.KeyValuePair,
// which compares pointer identity, so two semantically identical functions
// unmarshaled separately were never equal. Compare the pairs by value.
func (f *Function) Equal(other Function) bool {
	return f.Name == other.Name &&
		f.Type == other.Type &&
		f.Description == other.Description &&
		slices.Equal(f.InputFieldNames, other.InputFieldNames) &&
		slices.Equal(f.InputFieldIDs, other.InputFieldIDs) &&
		slices.Equal(f.OutputFieldNames, other.OutputFieldNames) &&
		slices.Equal(f.OutputFieldIDs, other.OutputFieldIDs) &&
		slices.EqualFunc(f.Params, other.Params, func(a, b *commonpb.KeyValuePair) bool {
			// generated getters are nil-safe
			return a.GetKey() == b.GetKey() && a.GetValue() == b.GetValue()
		})
}
// CloneFunctions returns a new slice containing a Clone of every function.
//
// Bug fix: the original built `clone` element by element but then returned
// the input slice `functions`, so callers received the original (aliased)
// models and every per-element Clone() was discarded.
func CloneFunctions(functions []*Function) []*Function {
	clone := make([]*Function, len(functions))
	for i, function := range functions {
		clone[i] = function.Clone()
	}
	return clone
}
// MarshalFunctionModel converts an in-memory Function into its protobuf
// schema representation. A nil model maps to a nil schema.
func MarshalFunctionModel(function *Function) *schemapb.FunctionSchema {
	if function == nil {
		return nil
	}
	schema := &schemapb.FunctionSchema{}
	schema.Name = function.Name
	schema.Id = function.ID
	schema.Description = function.Description
	schema.Type = function.Type
	schema.InputFieldIds = function.InputFieldIDs
	schema.InputFieldNames = function.InputFieldNames
	schema.OutputFieldIds = function.OutputFieldIDs
	schema.OutputFieldNames = function.OutputFieldNames
	schema.Params = function.Params
	return schema
}
// UnmarshalFunctionModel converts a protobuf function schema into the
// in-memory model. A nil schema maps to a nil model.
func UnmarshalFunctionModel(schema *schemapb.FunctionSchema) *Function {
	if schema == nil {
		return nil
	}
	function := &Function{}
	function.Name = schema.GetName()
	function.ID = schema.GetId()
	function.Description = schema.GetDescription()
	function.Type = schema.GetType()
	function.InputFieldIDs = schema.GetInputFieldIds()
	function.InputFieldNames = schema.GetInputFieldNames()
	function.OutputFieldIDs = schema.GetOutputFieldIds()
	function.OutputFieldNames = schema.GetOutputFieldNames()
	function.Params = schema.GetParams()
	return function
}
// MarshalFunctionModels converts a slice of function models into protobuf
// schemas, preserving order. A nil input maps to nil.
func MarshalFunctionModels(functions []*Function) []*schemapb.FunctionSchema {
	if functions == nil {
		return nil
	}
	out := make([]*schemapb.FunctionSchema, 0, len(functions))
	for _, function := range functions {
		out = append(out, MarshalFunctionModel(function))
	}
	return out
}
// UnmarshalFunctionModels converts a slice of protobuf function schemas into
// in-memory models, preserving order. A nil input maps to nil.
// (Renamed the result variable: it held []*Function, not schemas.)
func UnmarshalFunctionModels(functions []*schemapb.FunctionSchema) []*Function {
	if functions == nil {
		return nil
	}
	models := make([]*Function, 0, len(functions))
	for _, schema := range functions {
		models = append(models, UnmarshalFunctionModel(schema))
	}
	return models
}

View File

@ -0,0 +1,81 @@
package model
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
var (
functionSchemaPb = &schemapb.FunctionSchema{
Id: functionID,
Name: functionName,
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
OutputFieldNames: []string{"sparse"},
}
functionModel = &Function{
ID: functionID,
Name: functionName,
Type: schemapb.FunctionType_BM25,
InputFieldIDs: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIDs: []int64{103},
OutputFieldNames: []string{"sparse"},
}
)
// Round-trip tests for the Function model <-> protobuf conversions; nil
// inputs must map to nil outputs.

func TestMarshalFunctionModel(t *testing.T) {
	ret := MarshalFunctionModel(functionModel)
	assert.Equal(t, functionSchemaPb, ret)
	assert.Nil(t, MarshalFunctionModel(nil))
}

func TestMarshalFunctionModels(t *testing.T) {
	ret := MarshalFunctionModels([]*Function{functionModel})
	assert.Equal(t, []*schemapb.FunctionSchema{functionSchemaPb}, ret)
	assert.Nil(t, MarshalFunctionModels(nil))
}

func TestUnmarshalFunctionModel(t *testing.T) {
	ret := UnmarshalFunctionModel(functionSchemaPb)
	assert.Equal(t, functionModel, ret)
	assert.Nil(t, UnmarshalFunctionModel(nil))
}

func TestUnmarshalFunctionModels(t *testing.T) {
	ret := UnmarshalFunctionModels([]*schemapb.FunctionSchema{functionSchemaPb})
	assert.Equal(t, []*Function{functionModel}, ret)
	assert.Nil(t, UnmarshalFunctionModels(nil))
}
// TestFunctionEqual checks Function.Equal against an identical definition,
// a Clone of the model, and a definition that differs in OutputFieldIDs.
func TestFunctionEqual(t *testing.T) {
	sameFunction := Function{
		ID:               functionID,
		Name:             functionName,
		Type:             schemapb.FunctionType_BM25,
		InputFieldIDs:    []int64{101},
		InputFieldNames:  []string{"text"},
		OutputFieldIDs:   []int64{103},
		OutputFieldNames: []string{"sparse"},
	}
	// differs only in OutputFieldIDs (102 vs 103)
	differentFunction := Function{
		ID:               functionID,
		Name:             functionName,
		Type:             schemapb.FunctionType_BM25,
		InputFieldIDs:    []int64{101},
		InputFieldNames:  []string{"text"},
		OutputFieldIDs:   []int64{102},
		OutputFieldNames: []string{"sparse"},
	}

	assert.True(t, functionModel.Equal(sameFunction))
	assert.True(t, functionModel.Equal(*functionModel.Clone()))
	assert.False(t, functionModel.Equal(differentFunction))
}

View File

@ -719,6 +719,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection
Description: coll.Schema.Description,
AutoID: coll.Schema.AutoID,
Fields: make([]*schemapb.FieldSchema, 0),
Functions: make([]*schemapb.FunctionSchema, 0),
EnableDynamicField: coll.Schema.EnableDynamicField,
},
CollectionID: coll.CollectionID,
@ -735,6 +736,8 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection
resp.Schema.Fields = append(resp.Schema.Fields, field)
}
}
resp.Schema.Functions = append(resp.Schema.Functions, coll.Schema.Functions...)
return resp, nil
}

View File

@ -213,9 +213,10 @@ func TestMetaCache_GetCollection(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
id, err = globalMetaCache.GetCollectionID(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2)
@ -225,9 +226,10 @@ func TestMetaCache_GetCollection(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection2",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection2",
})
// test to get from cache, this should trigger root request
@ -239,9 +241,10 @@ func TestMetaCache_GetCollection(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
}
@ -298,9 +301,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
collection, err = globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1)
assert.Equal(t, rootCoord.GetAccessCount(), 1)
@ -310,9 +314,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection2",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection2",
})
// test to get from cache, this should trigger root request
@ -324,9 +329,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
}
@ -349,18 +355,20 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
rootCoord.Error = true
// should be cached with no error
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
}
@ -422,9 +430,10 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err)
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Functions: []*schemapb.FunctionSchema{},
Name: "collection1",
})
time.Sleep(10 * time.Millisecond)
}
@ -1071,6 +1080,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
dynamicField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: nil,
skipDynamicField: false,
@ -1091,6 +1101,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
dynamicField,
clusteringKeyField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: nil,
skipDynamicField: false,
@ -1111,6 +1122,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
dynamicField,
clusteringKeyField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"pk", "part_key", "vector", "clustering_key"},
skipDynamicField: false,
@ -1130,6 +1142,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
dynamicField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"pk", "part_key", "vector"},
skipDynamicField: true,
@ -1149,6 +1162,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
dynamicField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"part_key", "vector"},
skipDynamicField: true,
@ -1167,6 +1181,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
dynamicField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"pk", "vector"},
skipDynamicField: true,
@ -1185,6 +1200,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
dynamicField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"pk", "part_key"},
skipDynamicField: true,
@ -1203,6 +1219,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
vectorField,
clusteringKeyField,
},
Functions: []*schemapb.FunctionSchema{},
},
loadFields: []string{"pk", "part_key", "vector"},
expectErr: true,

View File

@ -301,6 +301,10 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
}
t.schema.AutoID = false
if err := validateFunction(t.schema); err != nil {
return err
}
if t.ShardsNum > Params.ProxyCfg.MaxShardNum.GetAsInt32() {
return fmt.Errorf("maximum shards's number should be limited to %d", Params.ProxyCfg.MaxShardNum.GetAsInt())
}
@ -632,6 +636,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error {
Description: "",
AutoID: false,
Fields: make([]*schemapb.FieldSchema, 0),
Functions: make([]*schemapb.FunctionSchema, 0),
},
CollectionID: 0,
VirtualChannelNames: nil,
@ -681,23 +686,28 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error {
}
if field.FieldID >= common.StartOfUserFieldID {
t.result.Schema.Fields = append(t.result.Schema.Fields, &schemapb.FieldSchema{
FieldID: field.FieldID,
Name: field.Name,
IsPrimaryKey: field.IsPrimaryKey,
AutoID: field.AutoID,
Description: field.Description,
DataType: field.DataType,
TypeParams: field.TypeParams,
IndexParams: field.IndexParams,
IsDynamic: field.IsDynamic,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
DefaultValue: field.DefaultValue,
ElementType: field.ElementType,
Nullable: field.Nullable,
FieldID: field.FieldID,
Name: field.Name,
IsPrimaryKey: field.IsPrimaryKey,
AutoID: field.AutoID,
Description: field.Description,
DataType: field.DataType,
TypeParams: field.TypeParams,
IndexParams: field.IndexParams,
IsDynamic: field.IsDynamic,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
DefaultValue: field.DefaultValue,
ElementType: field.ElementType,
Nullable: field.Nullable,
IsFunctionOutput: field.IsFunctionOutput,
})
}
}
for _, function := range result.Schema.Functions {
t.result.Schema.Functions = append(t.result.Schema.Functions, proto.Clone(function).(*schemapb.FunctionSchema))
}
return nil
}

View File

@ -71,6 +71,7 @@ type createIndexTask struct {
newExtraParams []*commonpb.KeyValuePair
collectionID UniqueID
functionSchema *schemapb.FunctionSchema
fieldSchema *schemapb.FieldSchema
userAutoIndexMetricTypeSpecified bool
}
@ -129,6 +130,48 @@ func wrapUserIndexParams(metricType string) []*commonpb.KeyValuePair {
}
}
// parseFunctionParamsToIndex merges index parameters derived from the
// field's producing function into indexParamsMap. Precedence (highest
// first): values already present in indexParamsMap (user-supplied), then
// the function's own params, then hard-coded BM25 defaults. No-op for
// fields that are not function outputs.
func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string]string) error {
	if !cit.fieldSchema.GetIsFunctionOutput() {
		return nil
	}

	switch cit.functionSchema.GetType() {
	case schemapb.FunctionType_BM25:
		// fallback values for every BM25-related index param
		defaults := map[string]string{
			"bm25_k1":    "1.2",
			"bm25_b":     "0.75",
			"bm25_avgdl": "100",
		}
		// function-level params fill in anything the user did not set;
		// keys other than the three BM25 params are ignored
		for _, kv := range cit.functionSchema.GetParams() {
			if _, known := defaults[kv.GetKey()]; !known {
				continue
			}
			if _, ok := indexParamsMap[kv.GetKey()]; !ok {
				indexParamsMap[kv.GetKey()] = kv.GetValue()
			}
		}
		// finally, apply defaults for whatever is still missing
		for key, value := range defaults {
			if _, ok := indexParamsMap[key]; !ok {
				indexParamsMap[key] = value
			}
		}
	default:
		return fmt.Errorf("parse unknown type function params to index")
	}
	return nil
}
func (cit *createIndexTask) parseIndexParams() error {
cit.newExtraParams = cit.req.GetExtraParams()
@ -149,6 +192,11 @@ func (cit *createIndexTask) parseIndexParams() error {
}
}
// fill index param for bm25 function
if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil {
return err
}
if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil {
return err
}
@ -353,18 +401,29 @@ func (cit *createIndexTask) parseIndexParams() error {
return nil
}
func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) {
// getIndexedFieldAndFunction resolves the field targeted by the create-index
// request from the collection schema and, when that field is a function
// output, also resolves the producing function. Results are stored on the
// task (cit.fieldSchema / cit.functionSchema) rather than returned.
func (cit *createIndexTask) getIndexedFieldAndFunction(ctx context.Context) error {
	schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.req.GetDbName(), cit.req.GetCollectionName())
	if err != nil {
		log.Error("failed to get collection schema", zap.Error(err))
		return fmt.Errorf("failed to get collection schema: %s", err)
	}

	field, err := schema.schemaHelper.GetFieldFromName(cit.req.GetFieldName())
	if err != nil {
		log.Error("create index on non-exist field", zap.Error(err))
		return fmt.Errorf("cannot create index on non-exist field: %s", cit.req.GetFieldName())
	}

	// function-output fields need the function schema too, so that index
	// params can later be derived from the function's params
	if field.IsFunctionOutput {
		function, err := schema.schemaHelper.GetFunctionByOutputField(field)
		if err != nil {
			log.Error("create index failed, cannot find function of function output field", zap.Error(err))
			return fmt.Errorf("create index failed, cannot find function of function output field: %s", cit.req.GetFieldName())
		}
		cit.functionSchema = function
	}
	cit.fieldSchema = field
	return nil
}
func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
@ -452,11 +511,11 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
return err
}
field, err := cit.getIndexedField(ctx)
err = cit.getIndexedFieldAndFunction(ctx)
if err != nil {
return err
}
cit.fieldSchema = field
// check index param, not accurate, only some static rules
err = cit.parseIndexParams()
if err != nil {

View File

@ -2170,7 +2170,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
})
}
func Test_createIndexTask_getIndexedField(t *testing.T) {
func Test_createIndexTask_getIndexedFieldAndFunction(t *testing.T) {
collectionName := "test"
fieldName := "test"
@ -2224,9 +2224,9 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
}), nil)
globalMetaCache = cache
field, err := cit.getIndexedField(context.Background())
err := cit.getIndexedFieldAndFunction(context.Background())
assert.NoError(t, err)
assert.Equal(t, fieldName, field.GetName())
assert.Equal(t, fieldName, cit.fieldSchema.GetName())
})
t.Run("schema not found", func(t *testing.T) {
@ -2237,7 +2237,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
mock.AnythingOfType("string"),
).Return(nil, errors.New("mock"))
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
err := cit.getIndexedFieldAndFunction(context.Background())
assert.Error(t, err)
})
@ -2256,7 +2256,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
},
}), nil)
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
err := cit.getIndexedFieldAndFunction(context.Background())
assert.Error(t, err)
})
}
@ -3128,6 +3128,10 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
},
},
}
sparseVecField := &schemapb.FieldSchema{
Name: "sparse",
DataType: schemapb.DataType_SparseFloatVector,
}
partitionKeyField := &schemapb.FieldSchema{
Name: "partition_key",
DataType: schemapb.DataType_Int64,
@ -3236,6 +3240,28 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
task.Schema = marshaledSchema
err = task.PreExecute(ctx)
assert.NoError(t, err)
// test schema with function
// invalid function
schema.Functions = []*schemapb.FunctionSchema{
{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"invalid name"}},
}
marshaledSchema, err = proto.Marshal(schema)
assert.NoError(t, err)
task.Schema = marshaledSchema
err = task.PreExecute(ctx)
assert.Error(t, err)
// normal case
schema.Fields = append(schema.Fields, sparseVecField)
schema.Functions = []*schemapb.FunctionSchema{
{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{varCharField.Name}, OutputFieldNames: []string{sparseVecField.Name}},
}
marshaledSchema, err = proto.Marshal(schema)
assert.NoError(t, err)
task.Schema = marshaledSchema
err = task.PreExecute(ctx)
assert.NoError(t, err)
})
t.Run("Execute", func(t *testing.T) {

View File

@ -25,6 +25,7 @@ import (
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/metadata"
@ -609,6 +610,143 @@ func validateSchema(coll *schemapb.CollectionSchema) error {
return nil
}
// validateFunction checks that the function declarations in a collection
// schema are well formed: function names are unique, every referenced
// input/output field exists, no field is produced by more than one function,
// and the per-type input/output/param constraints hold.
// Side effect: marks every resolved output field with IsFunctionOutput=true.
func validateFunction(coll *schemapb.CollectionSchema) error {
	nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
		return field.GetName(), field
	})
	usedOutputField := typeutil.NewSet[string]()
	usedFunctionName := typeutil.NewSet[string]()

	// validate function
	for _, function := range coll.GetFunctions() {
		if usedFunctionName.Contain(function.GetName()) {
			return fmt.Errorf("duplicate function name %s", function.GetName())
		}
		usedFunctionName.Insert(function.GetName())

		inputFields := make([]*schemapb.FieldSchema, 0, len(function.GetInputFieldNames()))
		for _, name := range function.GetInputFieldNames() {
			inputField, ok := nameMap[name]
			if !ok {
				// report the specific missing field, not the whole name list
				return fmt.Errorf("function input field not found %s", name)
			}
			inputFields = append(inputFields, inputField)
		}

		err := checkFunctionInputField(function, inputFields)
		if err != nil {
			return err
		}

		outputFields := make([]*schemapb.FieldSchema, len(function.GetOutputFieldNames()))
		for i, name := range function.GetOutputFieldNames() {
			outputField, ok := nameMap[name]
			if !ok {
				// bug fix: was printing function.InputFieldNames for a missing OUTPUT field
				return fmt.Errorf("function output field not found %s", name)
			}
			outputField.IsFunctionOutput = true
			outputFields[i] = outputField
			if usedOutputField.Contain(name) {
				return fmt.Errorf("duplicate function output %s", name)
			}
			usedOutputField.Insert(name)
		}

		if err := checkFunctionOutputField(function, outputFields); err != nil {
			return err
		}

		if err := checkFunctionParams(function); err != nil {
			return err
		}
	}
	return nil
}
// checkFunctionOutputField validates the output fields declared by a function.
// For BM25 the single output must be a sparse float vector field and must not
// act as the primary key, partition key, or clustering key.
func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
	switch function.GetType() {
	case schemapb.FunctionType_BM25:
		if len(fields) != 1 {
			return fmt.Errorf("bm25 only need 1 output field, but now %d", len(fields))
		}
		output := fields[0]
		if !typeutil.IsSparseFloatVectorType(output.GetDataType()) {
			return fmt.Errorf("bm25 only need sparse embedding output field, but now %s", output.DataType.String())
		}
		if output.GetIsPrimaryKey() {
			return fmt.Errorf("bm25 output field can't be primary key")
		}
		if output.GetIsPartitionKey() || output.GetIsClusteringKey() {
			return fmt.Errorf("bm25 output field can't be partition key or cluster key field")
		}
		return nil
	default:
		return fmt.Errorf("check output field for unknown function type")
	}
}
// checkFunctionInputField validates the input fields of a function.
// For BM25 exactly one VarChar input field is required.
func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
	switch function.GetType() {
	case schemapb.FunctionType_BM25:
		// Check the count before touching fields[0]: the original combined
		// condition formatted fields[0].DataType in the error message and
		// panicked with an index-out-of-range when fields was empty.
		if len(fields) != 1 {
			return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d fields", len(fields))
		}
		if fields[0].GetDataType() != schemapb.DataType_VarChar {
			return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d field with type %s",
				len(fields), fields[0].DataType.String())
		}
	default:
		return fmt.Errorf("check input field with unknown function type")
	}
	return nil
}
// checkFunctionParams validates the key/value parameters attached to a
// function. BM25 accepts bm25_k1 in [0,3], bm25_b in [0,1], a strictly
// positive bm25_avgdl, and analyzer_params; any other key is rejected.
func checkFunctionParams(function *schemapb.FunctionSchema) error {
	switch function.GetType() {
	case schemapb.FunctionType_BM25:
		for _, param := range function.GetParams() {
			key, value := param.GetKey(), param.GetValue()
			switch key {
			case "bm25_k1":
				k1, err := strconv.ParseFloat(value, 64)
				if err != nil {
					return fmt.Errorf("failed to parse bm25_k1 value, %w", err)
				}
				if k1 < 0 || k1 > 3 {
					return fmt.Errorf("bm25_k1 must in [0,3] but now %f", k1)
				}
			case "bm25_b":
				b, err := strconv.ParseFloat(value, 64)
				if err != nil {
					return fmt.Errorf("failed to parse bm25_b value, %w", err)
				}
				if b < 0 || b > 1 {
					return fmt.Errorf("bm25_b must in [0,1] but now %f", b)
				}
			case "bm25_avgdl":
				avgdl, err := strconv.ParseFloat(value, 64)
				if err != nil {
					return fmt.Errorf("failed to parse bm25_avgdl value, %w", err)
				}
				if avgdl <= 0 {
					return fmt.Errorf("bm25_avgdl must large than zero but now %f", avgdl)
				}
			case "analyzer_params":
				// TODO ADD tokenizer check
			default:
				return fmt.Errorf("invalid function params, key: %s, value:%s", key, value)
			}
		}
		return nil
	default:
		return fmt.Errorf("check function params with unknown function type")
	}
}
// validateMultipleVectorFields check if schema has multiple vector fields.
func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
vecExist := false
@ -754,13 +892,19 @@ func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData {
// fillFieldIDBySchema set fieldID to fieldData according FieldSchemas
func fillFieldIDBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
if len(columns) != len(schema.GetFields()) {
return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d",
len(columns), len(schema.GetFields()))
}
fieldName2Schema := make(map[string]*schemapb.FieldSchema)
expectColumnNum := 0
for _, field := range schema.GetFields() {
fieldName2Schema[field.Name] = field
if !field.GetIsFunctionOutput() {
expectColumnNum++
}
}
if len(columns) != expectColumnNum {
return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d",
expectColumnNum, len(columns))
}
for _, fieldData := range columns {
@ -1211,15 +1355,16 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey {
return merr.WrapErrParameterInvalidMsg("primary key can't be with default value")
}
if fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
// when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false
autoGenFieldNum++
}
if _, ok := dataNameSet[fieldSchema.GetName()]; !ok {
if fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
// autoGenField
continue
}
if fieldSchema.GetDefaultValue() == nil && !fieldSchema.GetNullable() {
log.Warn("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName()))
return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName())

View File

@ -248,6 +248,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
Description: colMeta.Description,
AutoID: colMeta.AutoID,
Fields: model.MarshalFieldModels(colMeta.Fields),
Functions: model.MarshalFunctionModels(colMeta.Functions),
},
PartitionIDs: partitionIDs,
StartPositions: colMeta.StartPositions,

View File

@ -259,10 +259,34 @@ func (t *createCollectionTask) validateSchema(schema *schemapb.CollectionSchema)
return validateFieldDataType(schema)
}
func (t *createCollectionTask) assignFieldID(schema *schemapb.CollectionSchema) {
for idx := range schema.GetFields() {
schema.Fields[idx].FieldID = int64(idx + StartOfUserFieldID)
// assignFieldAndFunctionID allocates sequential IDs (starting from the
// user-defined ID offsets) to the schema's fields and functions, then
// resolves each function's input/output field names to those field IDs.
// Returns an error if a function references a field name that does not exist.
func (t *createCollectionTask) assignFieldAndFunctionID(schema *schemapb.CollectionSchema) error {
	fieldIDByName := make(map[string]int64, len(schema.GetFields()))
	for idx, field := range schema.GetFields() {
		field.FieldID = int64(idx + StartOfUserFieldID)
		fieldIDByName[field.GetName()] = field.GetFieldID()
	}

	for fidx, function := range schema.GetFunctions() {
		function.Id = int64(fidx) + StartOfUserFunctionID

		function.InputFieldIds = make([]int64, len(function.InputFieldNames))
		for idx, name := range function.InputFieldNames {
			id, ok := fieldIDByName[name]
			if !ok {
				return fmt.Errorf("input field %s of function %s not found", name, function.GetName())
			}
			function.InputFieldIds[idx] = id
		}

		function.OutputFieldIds = make([]int64, len(function.OutputFieldNames))
		for idx, name := range function.OutputFieldNames {
			id, ok := fieldIDByName[name]
			if !ok {
				return fmt.Errorf("output field %s of function %s not found", name, function.GetName())
			}
			function.OutputFieldIds[idx] = id
		}
	}
	return nil
}
func (t *createCollectionTask) appendDynamicField(schema *schemapb.CollectionSchema) {
@ -303,7 +327,11 @@ func (t *createCollectionTask) prepareSchema() error {
return err
}
t.appendDynamicField(&schema)
t.assignFieldID(&schema)
if err := t.assignFieldAndFunctionID(&schema); err != nil {
return err
}
t.appendSysFields(&schema)
t.schema = &schema
return nil
@ -540,6 +568,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
Description: t.schema.Description,
AutoID: t.schema.AutoID,
Fields: model.UnmarshalFieldModels(t.schema.Fields),
Functions: model.UnmarshalFunctionModels(t.schema.Functions),
VirtualChannelNames: vchanNames,
PhysicalChannelNames: chanNames,
ShardsNum: t.Req.ShardsNum,
@ -609,6 +638,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error {
Description: collInfo.Description,
AutoID: collInfo.AutoID,
Fields: model.MarshalFieldModels(collInfo.Fields),
Functions: model.MarshalFunctionModels(collInfo.Functions),
},
},
}, &nullStep{})

View File

@ -29,6 +29,9 @@ const (
// StartOfUserFieldID id of user defined field begin from here
StartOfUserFieldID = common.StartOfUserFieldID
// StartOfUserFunctionID id of user defined function begin from here
StartOfUserFunctionID = common.StartOfUserFunctionID
// RowIDField id of row ID field
RowIDField = common.RowIDField

View File

@ -1124,6 +1124,7 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string, dbName str
Description: collInfo.Description,
AutoID: collInfo.AutoID,
Fields: model.MarshalFieldModels(collInfo.Fields),
Functions: model.MarshalFunctionModels(collInfo.Functions),
EnableDynamicField: collInfo.EnableDynamicField,
}
resp.CollectionID = collInfo.CollectionID

View File

@ -42,6 +42,8 @@ const (
// StartOfUserFieldID represents the starting ID of the user-defined field
StartOfUserFieldID = 100
// StartOfUserFunctionID represents the starting ID of the user-defined function
StartOfUserFunctionID = 100
// RowIDField is the ID of the RowID field reserved by the system
RowIDField = 0

View File

@ -11,10 +11,11 @@ require (
github.com/confluentinc/confluent-kafka-go v1.9.1
github.com/containerd/cgroups/v3 v3.0.3
github.com/expr-lang/expr v1.15.7
github.com/golang/protobuf v1.5.4
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/json-iterator/go v1.1.12
github.com/klauspost/compress v1.17.7
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd
github.com/nats-io/nats-server/v2 v2.10.12
github.com/nats-io/nats.go v1.34.1
github.com/panjf2000/ants/v2 v2.7.2
@ -93,7 +94,6 @@ require (
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/uuid v1.6.0 // indirect

View File

@ -494,8 +494,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb h1:S3QIkNv9N1Vd1UKtdaQ4yVDPFAwFiPSAjN07axzbR70=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd h1:x0b0+foTe23sKcVFseR1DE8+BB08EH6ViiRHaz8PEik=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A=
github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w=
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=

View File

@ -323,7 +323,6 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
if target.errCode == merr.errCode {
log.Info("mark error as input error", zap.Error(err))
WithErrorType(InputError)(&merr)
log.Info("test--", zap.String("type", merr.errType.String()))
return merr
}
}

View File

@ -429,6 +429,17 @@ func (helper *SchemaHelper) GetVectorDimFromID(fieldID int64) (int, error) {
return 0, fmt.Errorf("fieldID(%d) not has dim", fieldID)
}
// GetFunctionByOutputField returns the function schema that produces the
// given field, matching the field's ID against each function's output field
// IDs. It returns an error when no function declares the field as an output.
func (helper *SchemaHelper) GetFunctionByOutputField(field *schemapb.FieldSchema) (*schemapb.FunctionSchema, error) {
	for _, function := range helper.schema.GetFunctions() {
		for _, id := range function.GetOutputFieldIds() {
			if field.GetFieldID() == id {
				return function, nil
			}
		}
	}
	// include the field identity so callers can tell which lookup failed
	// (the original bare "function not exist" carried no context)
	return nil, fmt.Errorf("no function found for output field %s(%d)", field.GetName(), field.GetFieldID())
}
func IsBinaryVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_BinaryVector
}