diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 05ec9d6ec1..c91f22c671 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -1,6 +1,10 @@ package httpserver -import "time" +import ( + "time" + + "github.com/milvus-io/milvus/pkg/util/metric" +) // v2 const ( @@ -14,23 +18,23 @@ const ( AliasCategory = "/aliases/" ImportJobCategory = "/jobs/import/" - ListAction = "list" - HasAction = "has" - DescribeAction = "describe" - CreateAction = "create" - DropAction = "drop" - StatsAction = "get_stats" - LoadStateAction = "get_load_state" - RenameAction = "rename" - LoadAction = "load" - ReleaseAction = "release" - QueryAction = "query" - GetAction = "get" - DeleteAction = "delete" - InsertAction = "insert" - UpsertAction = "upsert" - SearchAction = "search" - HybridSearchAction = "hybrid_search" + ListAction = "list" + HasAction = "has" + DescribeAction = "describe" + CreateAction = "create" + DropAction = "drop" + StatsAction = "get_stats" + LoadStateAction = "get_load_state" + RenameAction = "rename" + LoadAction = "load" + ReleaseAction = "release" + QueryAction = "query" + GetAction = "get" + DeleteAction = "delete" + InsertAction = "insert" + UpsertAction = "upsert" + SearchAction = "search" + AdvancedSearchAction = "advanced_search" UpdatePasswordAction = "update_password" GrantRoleAction = "grant_role" @@ -72,6 +76,7 @@ const ( HTTPIndexName = "indexName" HTTPIndexField = "fieldName" HTTPAliasName = "aliasName" + HTTPRequestData = "data" DefaultDbName = "default" DefaultIndexName = "vector_idx" DefaultAliasName = "the_alias" @@ -112,7 +117,7 @@ const ( HTTPReturnGrantor = "grantor" HTTPReturnDbName = "dbName" - DefaultMetricType = "L2" + DefaultMetricType = metric.COSINE DefaultPrimaryFieldName = "id" DefaultVectorFieldName = "vector" diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 45d5c56ca6..40d627261c 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/requestutil" ) @@ -220,7 +221,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) { func (h *HandlersV1) createCollection(c *gin.Context) { httpReq := CreateCollectionReq{ DbName: DefaultDbName, - MetricType: DefaultMetricType, + MetricType: metric.L2, PrimaryField: DefaultPrimaryFieldName, VectorField: DefaultVectorFieldName, EnableDynamicField: EnableDynamic, diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index fd3d041388..8c071fdb1e 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -271,7 +271,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { name: "get load status fail", mp: mp2, exceptCode: http.StatusOK, - expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"\",\"shardsNum\":1}}", + expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"\",\"shardsNum\":1}}", }) mp3 := mocks.NewMockProxy(t) @@ -293,7 +293,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { name: "show collection details success", mp: mp4, exceptCode: http.StatusOK, - expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}", + expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}", }) for _, tt := range testCases { diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 51f0533089..3af1b57865 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -66,12 +66,12 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { } }, wrapperTraceLog(h.wrapperCheckDatabase(h.query))))) router.POST(EntityCategory+GetAction, timeoutMiddleware(wrapperPost(func() any { - return &CollectionIDOutputReq{ + return &CollectionIDReq{ OutputFields: []string{DefaultOutputFields}, } }, wrapperTraceLog(h.wrapperCheckDatabase(h.get))))) router.POST(EntityCategory+DeleteAction, timeoutMiddleware(wrapperPost(func() any { - return &CollectionIDFilterReq{} + return &CollectionFilterReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.delete))))) router.POST(EntityCategory+InsertAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionDataReq{} @@ -84,11 +84,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { Limit: 100, } }, wrapperTraceLog(h.wrapperCheckDatabase(h.search))))) - router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any { + router.POST(EntityCategory+AdvancedSearchAction, timeoutMiddleware(wrapperPost(func() any { return &HybridSearchReq{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch))))) + }, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch))))) router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions))))) router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions))))) @@ -534,7 +534,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa } func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { - httpReq := anyReq.(*CollectionIDOutputReq) + httpReq := anyReq.(*CollectionIDReq) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err @@ -577,7 +577,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName } func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { - httpReq := anyReq.(*CollectionIDFilterReq) + httpReq := anyReq.(*CollectionFilterReq) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err @@ -728,6 +728,43 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN return resp, err } +func generatePlaceholderGroup(ctx context.Context, body string, collSchema *schemapb.CollectionSchema, fieldName string) ([]byte, error) { + var err error + var vectorField *schemapb.FieldSchema + if len(fieldName) == 0 { + for _, field := range collSchema.Fields { + if IsVectorField(field) { + if len(fieldName) == 0 { + fieldName = field.Name + vectorField = field + } else { + return nil, errors.New("search without annsFields, but already found multiple vector fields: [" + fieldName + ", " + field.Name + "]") + } + } + } + } else { + for _, field := range collSchema.Fields { + if field.Name == fieldName && IsVectorField(field) { + vectorField = field + break + } + } + } + if vectorField == nil { + return nil, errors.New("cannot find a vector field named: " + fieldName) + } + dim, _ := getDim(vectorField) + phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim) + if err != nil { + return nil, err + } + return proto.Marshal(&commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + phv, + }, + }) +} + func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { params := map[string]interface{}{ // auto generated mapping "level": int(commonpb.ConsistencyLevel_Bounded), @@ -759,6 +796,10 @@ func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[str func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } searchParams, err := generateSearchParams(ctx, c, httpReq.Params) if err != nil { return nil, err @@ -766,12 +807,23 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) + body, _ := c.Get(gin.BodyBytesKey) + placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) + return nil, err + } req := &milvuspb.SearchRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, Dsl: httpReq.Filter, - PlaceholderGroup: vectors2PlaceholderGroupBytes(httpReq.Vector), + PlaceholderGroup: placeholderGroup, DslType: commonpb.DslType_BoolExprV1, OutputFields: httpReq.OutputFields, PartitionNames: httpReq.PartitionNames, @@ -803,14 +855,20 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN return resp, err } -func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { +func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*HybridSearchReq) req := &milvuspb.HybridSearchRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, Requests: []*milvuspb.SearchRequest{}, } - for _, subReq := range httpReq.Search { + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } + body, _ := c.Get(gin.BodyBytesKey) + searchArray := gjson.Get(string(body.([]byte)), "search").Array() + for i, subReq := range httpReq.Search { searchParams, err := generateSearchParams(ctx, c, subReq.Params) if err != nil { return nil, err @@ -820,11 +878,20 @@ func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq an searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) + placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) + return nil, err + } searchReq := &milvuspb.SearchRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, Dsl: subReq.Filter, - PlaceholderGroup: vectors2PlaceholderGroupBytes(subReq.Vector), + PlaceholderGroup: placeholderGroup, DslType: commonpb.DslType_BoolExprV1, OutputFields: httpReq.OutputFields, PartitionNames: httpReq.PartitionNames, @@ -871,20 +938,53 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe var err error fieldNames := map[string]bool{} if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + if httpReq.Dimension == 0 { + err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName", + "dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")") + log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } + idDataType := schemapb.DataType_Int64 + switch httpReq.IDType { + case "Varchar": + idDataType = schemapb.DataType_VarChar + case "": + httpReq.IDType = "Int64" + case "Int64": + default: + err := merr.WrapErrParameterInvalid("Int64, Varchar", httpReq.IDType, + "idType can only be [Int64, Varchar](case sensitive), default: Int64") + log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } + if len(httpReq.PrimaryFieldName) == 0 { + httpReq.PrimaryFieldName = PrimaryFieldName + } + if len(httpReq.VectorFieldName) == 0 { + httpReq.VectorFieldName = VectorFieldName + } schema, err = proto.Marshal(&schemapb.CollectionSchema{ Name: httpReq.CollectionName, AutoID: EnableAutoID, Fields: []*schemapb.FieldSchema{ { FieldID: common.StartOfUserFieldID, - Name: PrimaryFieldName, + Name: httpReq.PrimaryFieldName, IsPrimaryKey: true, - DataType: schemapb.DataType_Int64, + DataType: idDataType, AutoID: EnableAutoID, }, { FieldID: common.StartOfUserFieldID + 1, - Name: VectorFieldName, + Name: httpReq.VectorFieldName, IsPrimaryKey: false, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ @@ -967,6 +1067,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe return resp, err } if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + if len(httpReq.MetricType) == 0 { + httpReq.MetricType = DefaultMetricType + } createIndexReq := &milvuspb.CreateIndexRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 215d012a24..a7422caec9 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -400,13 +400,34 @@ func TestDatabaseWrapper(t *testing.T) { func TestCreateCollection(t *testing.T) { postTestCases := []requestBodyTestCase{} mp := mocks.NewMockProxy(t) - mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(7) - mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() - mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(9) + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4) + mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4) mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice() mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once() testEngine := initHTTPServerV2(mp, false) path := versionalV2(CollectionCategory, CreateAction) + // quickly create collection + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + errMsg: "dimension is required for quickly create collection(default metric type: COSINE): invalid parameter[expected=collectionName & dimension][actual=collectionName]", + errCode: 1100, // ErrParameterInvalid + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar"}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "unknown"}`), + errMsg: "idType can only be [Int64, Varchar](case sensitive), default: Int64: invalid parameter[expected=Int64, Varchar][actual=unknown]", + errCode: 1100, // ErrParameterInvalid + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2}`), + }) postTestCases = append(postTestCases, requestBodyTestCase{ path: path, requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricType": "L2"}`), @@ -995,13 +1016,7 @@ func TestDML(t *testing.T) { Status: &StatusSuccess, }, nil).Times(6) mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4) - mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) - mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{ - ErrorCode: 1700, // ErrFieldNotFound - Reason: "groupBy field not found in schema: field not found[field=test]", - }}, nil).Once() - mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice() - mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice() + mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Times(3) mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once() mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() @@ -1009,51 +1024,19 @@ func TestDML(t *testing.T) { mp.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus}, nil).Once() testEngine := initHTTPServerV2(mp, false) queryTestCases := []requestBodyTestCase{} - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`), - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), - errMsg: "can only accept json format request, error: invalid search params", - errCode: 1801, // ErrIncorrectParameterFormat - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`), - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), - errMsg: "groupBy field not found in schema: field not found[field=test]", - errCode: 65535, - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), - errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field SearchReqV2.vector of type float32", - errCode: 1801, - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: HybridSearchAction, - requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`), - }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: HybridSearchAction, - requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), - }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: QueryAction, requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]", "outputFields": ["book_id", "word_count", "book_intro"], "offset": 1}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: GetAction, - requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`), + requestBody: []byte(`{"collectionName": "book", "outputFields": ["book_id", "word_count", "book_intro"]}`), + errMsg: "missing required parameters, error: Key: 'CollectionIDReq.ID' Error:Field validation for 'ID' failed on the 'required' tag", + errCode: 1802, // ErrMissingRequiredParameters + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: QueryAction, + requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]"}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: InsertAction, @@ -1071,9 +1054,19 @@ func TestDML(t *testing.T) { path: UpsertAction, requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: DeleteAction, + requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`), + }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: DeleteAction, requestBody: []byte(`{"collectionName": "book", "id" : [0]}`), + errMsg: "missing required parameters, error: Key: 'CollectionFilterReq.Filter' Error:Field validation for 'Filter' failed on the 'required' tag", + errCode: 1802, // ErrMissingRequiredParameters + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: GetAction, + requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: GetAction, @@ -1095,7 +1088,7 @@ func TestDML(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: DeleteAction, - requestBody: []byte(`{"collectionName": "book", "id" : [0]}`), + requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`), errMsg: "", errCode: 65535, }) @@ -1120,3 +1113,180 @@ func TestDML(t *testing.T) { }) } } + +func TestSearchV2(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateCollectionSchema(schemapb.DataType_Int64), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Times(9) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{ + ErrorCode: 1700, // ErrFieldNotFound + Reason: "groupBy field not found in schema: field not found[field=test]", + }}, nil).Once() + mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice() + collSchema := generateCollectionSchema(schemapb.DataType_Int64) + binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector) + binaryVectorField.Name = "binaryVector" + float16VectorField := generateVectorFieldSchema(schemapb.DataType_Float16Vector) + float16VectorField.Name = "float16Vector" + bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector) + bfloat16VectorField.Name = "bfloat16Vector" + collSchema.Fields = append(collSchema.Fields, &binaryVectorField) + collSchema.Fields = append(collSchema.Fields, &float16VectorField) + collSchema.Fields = append(collSchema.Fields, &bfloat16VectorField) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: collSchema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Times(9) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice() + testEngine := initHTTPServerV2(mp, false) + queryTestCases := []requestBodyTestCase{} + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), + errMsg: "can only accept json format request, error: invalid search params", + errCode: 1801, // ErrIncorrectParameterFormat + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), + errMsg: "groupBy field not found in schema: field not found[field=test]", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go value of type float32: invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + }) + // annsField + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "word_count", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), + errMsg: "can only accept json format request, error: cannot find a vector field named: word_count", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`), + errMsg: "can only accept json format request, error: cannot find a vector field named: float_vector1", + errCode: 1801, + }) + // multiple annsFields + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + errMsg: "can only accept json format request, error: search without annsFields, but already found multiple vector fields: [book_intro, binaryVector]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal number 0.1 into Go value of type uint8: invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2, 0.3]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "can only accept json format request, error: dimension: 2, but length of []float: 3: invalid parameter[expected=FloatVector][actual=[0.1, 0.2, 0.3]]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQID"], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "can only accept json format request, error: dimension: 8, bytesLen: 1, but length of []byte: 3: invalid parameter[expected=BinaryVector][actual=\x01\x02\x03]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQID"], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=Float16Vector][actual=\x01\x02\x03]", + errCode: 1801, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQID"], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=BFloat16Vector][actual=\x01\x02\x03]", + errCode: 1801, + }) + + for _, testcase := range queryTestCases { + t.Run("search", func(t *testing.T) { + bodyReader := bytes.NewReader(testcase.requestBody) + req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(http.StatusOK), returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } +} diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 569af4a0dd..a4404a1b1f 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -99,7 +99,7 @@ type QueryReqV2 struct { func (req *QueryReqV2) GetDbName() string { return req.DbName } -type CollectionIDOutputReq struct { +type CollectionIDReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` PartitionName string `json:"partitionName"` @@ -108,17 +108,16 @@ type CollectionIDOutputReq struct { ID interface{} `json:"id" binding:"required"` } -func (req *CollectionIDOutputReq) GetDbName() string { return req.DbName } +func (req *CollectionIDReq) GetDbName() string { return req.DbName } -type CollectionIDFilterReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionName string `json:"partitionName"` - ID interface{} `json:"id"` - Filter string `json:"filter"` +type CollectionFilterReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + Filter string `json:"filter" binding:"required"` } -func (req *CollectionIDFilterReq) GetDbName() string { return req.DbName } +func (req *CollectionFilterReq) GetDbName() string { return req.DbName } type CollectionDataReq struct { DbName string `json:"dbName"` @@ -132,7 +131,8 @@ func (req *CollectionDataReq) GetDbName() string { return req.DbName } type SearchReqV2 struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` - Vector [][]float32 `json:"vector"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` PartitionNames []string `json:"partitionNames"` Filter string `json:"filter"` GroupByField string `json:"groupingField"` @@ -150,7 +150,7 @@ type Rand struct { } type SubSearchReq struct { - Vector [][]float32 `json:"vector"` + Data []interface{} `json:"data" binding:"required"` AnnsField string `json:"annsField"` Filter string `json:"filter"` GroupByField string `json:"groupingField"` @@ -296,12 +296,15 @@ type CollectionSchema struct { } type CollectionReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Dimension int32 `json:"dimension"` - MetricType string `json:"metricType"` - Schema CollectionSchema `json:"schema"` - IndexParams []IndexParam `json:"indexParams"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Dimension int32 `json:"dimension"` + IDType string `json:"idType"` + MetricType string `json:"metricType"` + PrimaryFieldName string `json:"primaryFieldName"` + VectorFieldName string `json:"vectorFieldName"` + Schema CollectionSchema `json:"schema"` + IndexParams []IndexParam `json:"indexParams"` } func (req *CollectionReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index eacf572dcb..3434b670a7 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -183,7 +183,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H { func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}) { var reallyDataArray []map[string]interface{} - dataResult := gjson.Get(body, "data") + dataResult := gjson.Get(body, HTTPRequestData) dataResultArray := dataResult.Array() if len(dataResultArray) == 0 { return merr.ErrMissingRequiredParameters, reallyDataArray @@ -914,6 +914,67 @@ func serialize(fv []float32) []byte { return data } +func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) { + values := make([][]byte, 0) + for _, vector := range vectors { + var vectorArray []float32 + err := json.Unmarshal([]byte(vector.String()), &vectorArray) + if err != nil { + return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error()) + } + if int64(len(vectorArray)) != dimension { + return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), + fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(vectorArray))) + } + vectorBytes := serialize(vectorArray) + values = append(values, vectorBytes) + } + return values, nil +} + +func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) { + values := make([][]byte, 0) + err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2 + if err != nil { + return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error()) + } + for _, vectorArray := range values { + if int64(len(vectorArray)) != bytesLen { + return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], string(vectorArray), + fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(vectorArray))) + } + } + return values, nil +} + +func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) { + var valueType commonpb.PlaceholderType + var values [][]byte + var err error + switch dataType { + case schemapb.DataType_FloatVector: + valueType = commonpb.PlaceholderType_FloatVector + values, err = serializeFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType, dimension, dimension*4) + case schemapb.DataType_BinaryVector: + valueType = commonpb.PlaceholderType_BinaryVector + values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8) + case schemapb.DataType_Float16Vector: + valueType = commonpb.PlaceholderType_Float16Vector + values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2) + case schemapb.DataType_BFloat16Vector: + valueType = commonpb.PlaceholderType_BFloat16Vector + values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2) + } + if err != nil { + return nil, err + } + return &commonpb.PlaceholderValue{ + Tag: "$0", + Type: valueType, + Values: values, + }, nil +} + // todo: support [][]byte for BinaryVector func vectors2PlaceholderGroupBytes(vectors [][]float32) []byte { var placeHolderType commonpb.PlaceholderType diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 2f444ee7e1..6db1e7a9db 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -492,6 +493,42 @@ func TestSerialize(t *testing.T) { parameters := []float32{0.11111, 0.22222} assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters))) assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo + requestBody := "{\"data\": [[0.11111, 0.22222]]}" + vectors := gjson.Get(requestBody, HTTPRequestData) + values, err := serializeFloatVectors(vectors.Array(), schemapb.DataType_FloatVector, 2, -1) + assert.Nil(t, err) + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_FloatVector, + Values: values, + } + bytes, err := proto.Marshal(&commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + placeholderValue, + }, + }) + assert.Nil(t, err) + assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(bytes)) // todo + for _, dataType := range []schemapb.DataType{schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector} { + request := map[string]interface{}{ + HTTPRequestData: []interface{}{ + []byte{1, 2}, + }, + } + requestBody, _ := json.Marshal(request) + values, err = serializeByteVectors(gjson.Get(string(requestBody), HTTPRequestData).Raw, dataType, -1, 2) + assert.Nil(t, err) + placeholderValue = &commonpb.PlaceholderValue{ + Tag: "$0", + Values: values, + } + _, err = proto.Marshal(&commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + placeholderValue, + }, + }) + assert.Nil(t, err) + } } func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {