mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
feat: Restful support for BM25 function (#36713)
issue: https://github.com/milvus-io/milvus/issues/35853 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
parent
e170991a10
commit
16b533cbf0
@ -95,14 +95,22 @@ const (
|
||||
|
||||
HTTPReturnHas = "has"
|
||||
|
||||
HTTPReturnFieldName = "name"
|
||||
HTTPReturnFieldID = "id"
|
||||
HTTPReturnFieldType = "type"
|
||||
HTTPReturnFieldPrimaryKey = "primaryKey"
|
||||
HTTPReturnFieldPartitionKey = "partitionKey"
|
||||
HTTPReturnFieldAutoID = "autoId"
|
||||
HTTPReturnFieldElementType = "elementType"
|
||||
HTTPReturnDescription = "description"
|
||||
HTTPReturnFieldName = "name"
|
||||
HTTPReturnFieldID = "id"
|
||||
HTTPReturnFieldType = "type"
|
||||
HTTPReturnFieldPrimaryKey = "primaryKey"
|
||||
HTTPReturnFieldPartitionKey = "partitionKey"
|
||||
HTTPReturnFieldAutoID = "autoId"
|
||||
HTTPReturnFieldElementType = "elementType"
|
||||
HTTPReturnDescription = "description"
|
||||
HTTPReturnFieldIsFunctionOutput = "isFunctionOutput"
|
||||
|
||||
HTTPReturnFunctionName = "name"
|
||||
HTTPReturnFunctionID = "id"
|
||||
HTTPReturnFunctionType = "type"
|
||||
HTTPReturnFunctionInputFieldNames = "inputFieldNames"
|
||||
HTTPReturnFunctionOutputFieldNames = "outputFieldNames"
|
||||
HTTPReturnFunctionParams = "params"
|
||||
|
||||
HTTPReturnIndexMetricType = "metricType"
|
||||
HTTPReturnIndexType = "indexType"
|
||||
|
@ -437,6 +437,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
|
||||
HTTPReturnDescription: coll.Schema.Description,
|
||||
HTTPReturnFieldAutoID: autoID,
|
||||
"fields": printFieldsV2(coll.Schema.Fields),
|
||||
"functions": printFunctionDetails(coll.Schema.Functions),
|
||||
"aliases": aliases,
|
||||
"indexes": indexDesc,
|
||||
"load": collLoadState,
|
||||
@ -897,7 +898,21 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
||||
if !typeutil.IsSparseFloatVectorType(vectorField.DataType) {
|
||||
dim, _ = getDim(vectorField)
|
||||
}
|
||||
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
|
||||
|
||||
dataType := vectorField.DataType
|
||||
|
||||
if vectorField.GetIsFunctionOutput() {
|
||||
for _, function := range collSchema.Functions {
|
||||
if function.Type == schemapb.FunctionType_BM25 {
|
||||
// TODO: currently only BM25 function is supported, thus guarantees one input field to one output field
|
||||
if function.OutputFieldNames[0] == vectorField.Name {
|
||||
dataType = schemapb.DataType_VarChar
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
phv, err := convertQueries2Placeholder(body, dataType, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1086,6 +1101,17 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
fieldNames := map[string]bool{}
|
||||
partitionsNum := int64(-1)
|
||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.Schema.Functions) > 0 {
|
||||
err := merr.WrapErrParameterInvalid("schema", "functions",
|
||||
"functions are not supported for quickly create collection")
|
||||
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if httpReq.Dimension == 0 {
|
||||
err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName",
|
||||
"dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")")
|
||||
@ -1162,8 +1188,40 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
Name: httpReq.CollectionName,
|
||||
AutoID: httpReq.Schema.AutoId,
|
||||
Fields: []*schemapb.FieldSchema{},
|
||||
Functions: []*schemapb.FunctionSchema{},
|
||||
EnableDynamicField: httpReq.Schema.EnableDynamicField,
|
||||
}
|
||||
|
||||
allOutputFields := []string{}
|
||||
|
||||
for _, function := range httpReq.Schema.Functions {
|
||||
functionTypeValue, ok := schemapb.FunctionType_value[function.FunctionType]
|
||||
if !ok {
|
||||
log.Ctx(ctx).Warn("function's data type is invalid(case sensitive).", zap.Any("function.DataType", function.FunctionType), zap.Any("function", function))
|
||||
err := merr.WrapErrParameterInvalid("FunctionType", function.FunctionType, "function data type is invalid(case sensitive)")
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrParameterInvalid),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
functionType := schemapb.FunctionType(functionTypeValue)
|
||||
description := function.Description
|
||||
params := []*commonpb.KeyValuePair{}
|
||||
for key, value := range function.Params {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
}
|
||||
collSchema.Functions = append(collSchema.Functions, &schemapb.FunctionSchema{
|
||||
Name: function.FunctionName,
|
||||
Description: description,
|
||||
Type: functionType,
|
||||
InputFieldNames: function.InputFieldNames,
|
||||
OutputFieldNames: function.OutputFieldNames,
|
||||
Params: params,
|
||||
})
|
||||
allOutputFields = append(allOutputFields, function.OutputFieldNames...)
|
||||
}
|
||||
|
||||
for _, field := range httpReq.Schema.Fields {
|
||||
fieldDataType, ok := schemapb.DataType_value[field.DataType]
|
||||
if !ok {
|
||||
@ -1218,6 +1276,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
for key, fieldParam := range field.ElementTypeParams {
|
||||
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", fieldParam)})
|
||||
}
|
||||
if lo.Contains(allOutputFields, field.FieldName) {
|
||||
fieldSchema.IsFunctionOutput = true
|
||||
}
|
||||
collSchema.Fields = append(collSchema.Fields, &fieldSchema)
|
||||
fieldNames[field.FieldName] = true
|
||||
}
|
||||
|
@ -57,6 +57,22 @@ func init() {
|
||||
paramtable.Init()
|
||||
}
|
||||
|
||||
func sendReqAndVerify(t *testing.T, testEngine *gin.Engine, testName, method string, testcase requestBodyTestCase) {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
req := httptest.NewRequest(method, testcase.path, bytes.NewReader(testcase.requestBody))
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
returnBody := &ReturnErrMsg{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), returnBody)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testcase.errCode, returnBody.Code)
|
||||
if testcase.errCode != 0 {
|
||||
assert.Contains(t, returnBody.Message, testcase.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPWrapper(t *testing.T) {
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
postTestCasesTrace := []requestBodyTestCase{}
|
||||
@ -468,6 +484,230 @@ func TestDatabaseWrapper(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocInDocOutCreateCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(1)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
path := versionalV2(CollectionCategory, CreateAction)
|
||||
|
||||
const baseRequestBody = `{
|
||||
"collectionName": "doc_in_doc_out_demo",
|
||||
"schema": {
|
||||
"autoId": false,
|
||||
"enableDynamicField": false,
|
||||
"fields": [
|
||||
{
|
||||
"fieldName": "my_id",
|
||||
"dataType": "Int64",
|
||||
"isPrimary": true
|
||||
},
|
||||
{
|
||||
"fieldName": "document_content",
|
||||
"dataType": "VarChar",
|
||||
"elementTypeParams": {
|
||||
"max_length": "9000"
|
||||
}
|
||||
},
|
||||
{
|
||||
"fieldName": "sparse_vector_1",
|
||||
"dataType": "SparseFloatVector"
|
||||
}
|
||||
],
|
||||
"functions": %s
|
||||
}
|
||||
}`
|
||||
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||
{
|
||||
"name": "bm25_fn_1",
|
||||
"type": "BM25",
|
||||
"inputFieldNames": ["document_content"],
|
||||
"outputFieldNames": ["sparse_vector_1"]
|
||||
}
|
||||
]`)),
|
||||
})
|
||||
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||
{
|
||||
"name": "bm25_fn_1",
|
||||
"type": "BM25_",
|
||||
"inputFieldNames": ["document_content"],
|
||||
"outputFieldNames": ["sparse_vector_1"]
|
||||
}
|
||||
]`)),
|
||||
errMsg: "actual=BM25_",
|
||||
errCode: 1100,
|
||||
})
|
||||
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||
{
|
||||
"name": "bm25_fn_1",
|
||||
"inputFieldNames": ["document_content"],
|
||||
"outputFieldNames": ["sparse_vector_1"]
|
||||
}
|
||||
]`)),
|
||||
errMsg: "actual=", // unprovided function type is empty string
|
||||
errCode: 1100,
|
||||
})
|
||||
|
||||
for _, testcase := range postTestCases {
|
||||
sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocInDocOutCreateCollectionQuickDisallowFunction(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
mp := mocks.NewMockProxy(t)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
path := versionalV2(CollectionCategory, CreateAction)
|
||||
|
||||
const baseRequestBody = `{
|
||||
"collectionName": "doc_in_doc_out_demo",
|
||||
"dimension": 2,
|
||||
"idType": "Varchar",
|
||||
"schema": {
|
||||
"autoId": false,
|
||||
"enableDynamicField": false,
|
||||
"functions": [
|
||||
{
|
||||
"name": "bm25_fn_1",
|
||||
"type": "BM25",
|
||||
"inputFieldNames": ["document_content"],
|
||||
"outputFieldNames": ["sparse_vector_1"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
testcase := requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(baseRequestBody),
|
||||
errMsg: "functions are not supported for quickly create collection",
|
||||
errCode: 1100,
|
||||
}
|
||||
|
||||
sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestDocInDocOutDescribeCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Once()
|
||||
mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Once()
|
||||
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&DefaultDescIndexesReqp, nil).Once()
|
||||
mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{
|
||||
Status: &StatusSuccess,
|
||||
Aliases: []string{DefaultAliasName},
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
testcase := requestBodyTestCase{
|
||||
path: versionalV2(CollectionCategory, DescribeAction),
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`),
|
||||
}
|
||||
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestDocInDocOutInsert(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
mp := mocks.NewMockProxy(t)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
|
||||
|
||||
testcase := requestBodyTestCase{
|
||||
path: versionalV2(EntityCategory, InsertAction),
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "varchar_field": "some text"}]}`),
|
||||
}
|
||||
|
||||
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestDocInDocOutInsertInvalid(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
mp := mocks.NewMockProxy(t)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Once()
|
||||
// invlaid insert request, will not be sent to proxy
|
||||
|
||||
testcase := requestBodyTestCase{
|
||||
path: versionalV2(EntityCategory, InsertAction),
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": {"1": 0.1}, "varchar_field": "some text"}]}`),
|
||||
errCode: 1804,
|
||||
errMsg: "not allowed to provide input data for function output field",
|
||||
}
|
||||
|
||||
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestDocInDocOutSearch(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
mp := mocks.NewMockProxy(t)
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
|
||||
TopK: int64(3),
|
||||
OutputFields: []string{FieldWordCount},
|
||||
FieldsData: generateFieldData(),
|
||||
Ids: generateIDs(schemapb.DataType_Int64, 3),
|
||||
Scores: DefaultScores,
|
||||
}}, nil).Once()
|
||||
|
||||
testcase := requestBodyTestCase{
|
||||
path: versionalV2(EntityCategory, SearchAction),
|
||||
requestBody: []byte(`{"collectionName": "book", "data": ["query data"], "limit": 4, "outputFields": ["word_count"]}`),
|
||||
}
|
||||
|
||||
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestCreateCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
@ -1054,7 +1294,6 @@ func TestMethodGet(t *testing.T) {
|
||||
if testcase.errCode != 0 {
|
||||
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
||||
}
|
||||
fmt.Println(w.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -324,10 +324,20 @@ type FieldSchema struct {
|
||||
DefaultValue interface{} `json:"defaultValue" binding:"required"`
|
||||
}
|
||||
|
||||
// FunctionSchema is the JSON form of one function entry in a collection
// schema ("schema.functions" in a request body).
//
// Fields tagged binding:"required" are mandatory in the request:
// name, type, inputFieldNames and outputFieldNames.
type FunctionSchema struct {
	// FunctionName is the function's name (JSON key "name").
	FunctionName string `json:"name" binding:"required"`
	// Description is optional free text (JSON key "description").
	Description string `json:"description"`
	// FunctionType is the function type as a string, e.g. "BM25"
	// (JSON key "type").
	FunctionType string `json:"type" binding:"required"`
	// InputFieldNames lists the fields the function reads.
	InputFieldNames []string `json:"inputFieldNames" binding:"required"`
	// OutputFieldNames lists the fields the function produces.
	OutputFieldNames []string `json:"outputFieldNames" binding:"required"`
	// Params carries optional function parameters as arbitrary JSON values.
	Params map[string]interface{} `json:"params"`
}
|
||||
|
||||
type CollectionSchema struct {
|
||||
Fields []FieldSchema `json:"fields"`
|
||||
AutoId bool `json:"autoID"`
|
||||
EnableDynamicField bool `json:"enableDynamicField"`
|
||||
Fields []FieldSchema `json:"fields"`
|
||||
Functions []FunctionSchema `json:"functions"`
|
||||
AutoId bool `json:"autoID"`
|
||||
EnableDynamicField bool `json:"enableDynamicField"`
|
||||
}
|
||||
|
||||
type CollectionReq struct {
|
||||
|
@ -147,52 +147,77 @@ func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result)
|
||||
// --------------------- collection details --------------------- //
|
||||
|
||||
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||
return printFieldDetails(fields, true)
|
||||
var res []gin.H
|
||||
for _, field := range fields {
|
||||
fieldDetail := printFieldDetail(field, true)
|
||||
res = append(res, fieldDetail)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H {
|
||||
return printFieldDetails(fields, false)
|
||||
}
|
||||
|
||||
func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H {
|
||||
var res []gin.H
|
||||
for _, field := range fields {
|
||||
fieldDetail := gin.H{
|
||||
HTTPReturnFieldName: field.Name,
|
||||
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
|
||||
HTTPReturnFieldPartitionKey: field.IsPartitionKey,
|
||||
HTTPReturnFieldAutoID: field.AutoID,
|
||||
HTTPReturnDescription: field.Description,
|
||||
}
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
dim, _ := getDim(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||
}
|
||||
} else if field.DataType == schemapb.DataType_VarChar {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||
}
|
||||
} else {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
}
|
||||
if !oldVersion {
|
||||
fieldDetail[HTTPReturnFieldID] = field.FieldID
|
||||
if field.TypeParams != nil {
|
||||
fieldDetail[Params] = field.TypeParams
|
||||
}
|
||||
if field.DataType == schemapb.DataType_Array {
|
||||
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
|
||||
}
|
||||
}
|
||||
fieldDetail := printFieldDetail(field, false)
|
||||
res = append(res, fieldDetail)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func printFieldDetail(field *schemapb.FieldSchema, oldVersion bool) gin.H {
|
||||
fieldDetail := gin.H{
|
||||
HTTPReturnFieldName: field.Name,
|
||||
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
|
||||
HTTPReturnFieldPartitionKey: field.IsPartitionKey,
|
||||
HTTPReturnFieldAutoID: field.AutoID,
|
||||
HTTPReturnDescription: field.Description,
|
||||
}
|
||||
if field.GetIsFunctionOutput() {
|
||||
fieldDetail[HTTPReturnFieldIsFunctionOutput] = true
|
||||
}
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
dim, _ := getDim(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||
}
|
||||
} else if field.DataType == schemapb.DataType_VarChar {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||
}
|
||||
} else {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
}
|
||||
if !oldVersion {
|
||||
fieldDetail[HTTPReturnFieldID] = field.FieldID
|
||||
if field.TypeParams != nil {
|
||||
fieldDetail[Params] = field.TypeParams
|
||||
}
|
||||
if field.DataType == schemapb.DataType_Array {
|
||||
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
|
||||
}
|
||||
}
|
||||
return fieldDetail
|
||||
}
|
||||
|
||||
func printFunctionDetails(functions []*schemapb.FunctionSchema) []gin.H {
|
||||
var res []gin.H
|
||||
for _, function := range functions {
|
||||
res = append(res, gin.H{
|
||||
HTTPReturnFunctionName: function.Name,
|
||||
HTTPReturnDescription: function.Description,
|
||||
HTTPReturnFunctionType: function.Type,
|
||||
HTTPReturnFunctionID: function.Id,
|
||||
HTTPReturnFunctionInputFieldNames: function.InputFieldNames,
|
||||
HTTPReturnFunctionOutputFieldNames: function.OutputFieldNames,
|
||||
HTTPReturnFunctionParams: function.Params,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func getMetricType(pairs []*commonpb.KeyValuePair) string {
|
||||
metricType := DefaultMetricType
|
||||
for _, pair := range pairs {
|
||||
@ -258,6 +283,14 @@ func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error,
|
||||
continue
|
||||
}
|
||||
|
||||
// if field is a function output field, user must not provide data for it
|
||||
if field.GetIsFunctionOutput() {
|
||||
if dataString != "" {
|
||||
return merr.WrapErrParameterInvalid("", "not allowed to provide input data for function output field: "+fieldName), reallyDataArray, validDataMap
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
if dataString == "" {
|
||||
@ -626,11 +659,16 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
nameColumns := make(map[string]interface{})
|
||||
nameDims := make(map[string]int64)
|
||||
fieldData := make(map[string]*schemapb.FieldData)
|
||||
|
||||
for _, field := range sch.Fields {
|
||||
// skip auto id pk field
|
||||
if (field.IsPrimaryKey && field.AutoID) || field.IsDynamic {
|
||||
continue
|
||||
}
|
||||
// skip function output field
|
||||
if field.GetIsFunctionOutput() {
|
||||
continue
|
||||
}
|
||||
var data interface{}
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
@ -685,8 +723,8 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
IsDynamic: field.IsDynamic,
|
||||
}
|
||||
}
|
||||
if len(nameDims) == 0 {
|
||||
return nil, fmt.Errorf("collection: %s has no vector field", sch.Name)
|
||||
if len(nameDims) == 0 && len(sch.Functions) == 0 {
|
||||
return nil, fmt.Errorf("collection: %s has no vector field or functions", sch.Name)
|
||||
}
|
||||
|
||||
dynamicCol := make([][]byte, 0, rowsLen)
|
||||
@ -709,6 +747,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
if (field.Nullable || field.DefaultValue != nil) && !ok {
|
||||
continue
|
||||
}
|
||||
if field.GetIsFunctionOutput() {
|
||||
if ok {
|
||||
return nil, fmt.Errorf("row %d has data provided for function output field %s", idx, field.Name)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
||||
}
|
||||
@ -1035,7 +1079,7 @@ func serializeSparseFloatVectors(vectors []gjson.Result, dataType schemapb.DataT
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
||||
func convertQueries2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
||||
var valueType commonpb.PlaceholderType
|
||||
var values [][]byte
|
||||
var err error
|
||||
@ -1055,6 +1099,12 @@ func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimensi
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
valueType = commonpb.PlaceholderType_SparseFloatVector
|
||||
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
|
||||
case schemapb.DataType_VarChar:
|
||||
valueType = commonpb.PlaceholderType_VarChar
|
||||
res := gjson.Get(body, HTTPRequestData).Array()
|
||||
for _, v := range res {
|
||||
values = append(values, []byte(v.String()))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -23,6 +23,7 @@ const (
|
||||
FieldWordCount = "word_count"
|
||||
FieldBookID = "book_id"
|
||||
FieldBookIntro = "book_intro"
|
||||
FieldVarchar = "varchar_field"
|
||||
)
|
||||
|
||||
var DefaultScores = []float32{0.01, 0.04, 0.09}
|
||||
@ -74,17 +75,21 @@ func generateVectorFieldSchema(dataType schemapb.DataType) *schemapb.FieldSchema
|
||||
if dataType == schemapb.DataType_BinaryVector {
|
||||
dim = "8"
|
||||
}
|
||||
typeParams := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: dim,
|
||||
},
|
||||
}
|
||||
if dataType == schemapb.DataType_SparseFloatVector {
|
||||
typeParams = nil
|
||||
}
|
||||
return &schemapb.FieldSchema{
|
||||
FieldID: common.StartOfUserFieldID + int64(dataType),
|
||||
IsPrimaryKey: false,
|
||||
DataType: dataType,
|
||||
AutoID: false,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: dim,
|
||||
},
|
||||
},
|
||||
TypeParams: typeParams,
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,6 +115,44 @@ func generateCollectionSchema(primaryDataType schemapb.DataType) *schemapb.Colle
|
||||
}
|
||||
}
|
||||
|
||||
func generateDocInDocOutCollectionSchema(primaryDataType schemapb.DataType) *schemapb.CollectionSchema {
|
||||
primaryField := generatePrimaryField(primaryDataType)
|
||||
vectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
|
||||
vectorField.Name = FieldBookIntro
|
||||
vectorField.IsFunctionOutput = true
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: DefaultCollectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
primaryField, {
|
||||
FieldID: common.StartOfUserFieldID + 1,
|
||||
Name: FieldWordCount,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: 5,
|
||||
AutoID: false,
|
||||
}, vectorField, {
|
||||
FieldID: common.StartOfUserFieldID + 2,
|
||||
Name: FieldVarchar,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
AutoID: false,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "sum",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{FieldVarchar},
|
||||
OutputFieldNames: []string{FieldBookIntro},
|
||||
},
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
}
|
||||
}
|
||||
|
||||
func generateIndexes() []*milvuspb.IndexDescription {
|
||||
return []*milvuspb.IndexDescription{
|
||||
{
|
||||
|
@ -86,7 +86,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||
var consistencyLevel commonpb.ConsistencyLevel
|
||||
useDefaultConsistency := t.request.GetUseDefaultConsistency()
|
||||
if !useDefaultConsistency {
|
||||
// legacy SDK & resultful behavior
|
||||
// legacy SDK & restful behavior
|
||||
if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 {
|
||||
return true
|
||||
}
|
||||
@ -373,7 +373,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
|
||||
// set PartitionIDs for sub search
|
||||
if t.partitionKeyMode {
|
||||
// isolatioin has tighter constraint, check first
|
||||
// isolation has tighter constraint, check first
|
||||
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
||||
if mvErr != nil {
|
||||
return mvErr
|
||||
@ -453,7 +453,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
||||
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
|
||||
|
||||
if t.partitionKeyMode {
|
||||
// isolatioin has tighter constraint, check first
|
||||
// isolation has tighter constraint, check first
|
||||
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
||||
if mvErr != nil {
|
||||
return mvErr
|
||||
|
@ -296,7 +296,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
||||
defer sd.lifetime.Done()
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received search request not belongs to it",
|
||||
log.Warn("delegator received search request not belongs to it",
|
||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||
)
|
||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||
|
Loading…
Reference in New Issue
Block a user