fix: [restful v2] search & advanced_search API (#31113)

issue: #30688
former pr:  #30946

1. param `vector` is required #31012
2. param `annsField` is optional, for multiply vector fields #31010
3. support BinaryVector, Float16Vector, BFloat16Vector #31013
4. replace vector with data, to align with pymilvus milvus_client #31093
5. create collection quickly, to align with pymilvus milvus_client
#31149

---------

Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
PowderLi 2024-03-11 15:57:01 +08:00 committed by GitHub
parent 9cfe183253
commit 58d7b9f902
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 483 additions and 103 deletions

View File

@ -1,6 +1,10 @@
package httpserver
import "time"
import (
"time"
"github.com/milvus-io/milvus/pkg/util/metric"
)
// v2
const (
@ -14,23 +18,23 @@ const (
AliasCategory = "/aliases/"
ImportJobCategory = "/jobs/import/"
ListAction = "list"
HasAction = "has"
DescribeAction = "describe"
CreateAction = "create"
DropAction = "drop"
StatsAction = "get_stats"
LoadStateAction = "get_load_state"
RenameAction = "rename"
LoadAction = "load"
ReleaseAction = "release"
QueryAction = "query"
GetAction = "get"
DeleteAction = "delete"
InsertAction = "insert"
UpsertAction = "upsert"
SearchAction = "search"
HybridSearchAction = "hybrid_search"
ListAction = "list"
HasAction = "has"
DescribeAction = "describe"
CreateAction = "create"
DropAction = "drop"
StatsAction = "get_stats"
LoadStateAction = "get_load_state"
RenameAction = "rename"
LoadAction = "load"
ReleaseAction = "release"
QueryAction = "query"
GetAction = "get"
DeleteAction = "delete"
InsertAction = "insert"
UpsertAction = "upsert"
SearchAction = "search"
AdvancedSearchAction = "advanced_search"
UpdatePasswordAction = "update_password"
GrantRoleAction = "grant_role"
@ -72,6 +76,7 @@ const (
HTTPIndexName = "indexName"
HTTPIndexField = "fieldName"
HTTPAliasName = "aliasName"
HTTPRequestData = "data"
DefaultDbName = "default"
DefaultIndexName = "vector_idx"
DefaultAliasName = "the_alias"
@ -112,7 +117,7 @@ const (
HTTPReturnGrantor = "grantor"
HTTPReturnDbName = "dbName"
DefaultMetricType = "L2"
DefaultMetricType = metric.COSINE
DefaultPrimaryFieldName = "id"
DefaultVectorFieldName = "vector"

View File

@ -22,6 +22,7 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/requestutil"
)
@ -220,7 +221,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) {
func (h *HandlersV1) createCollection(c *gin.Context) {
httpReq := CreateCollectionReq{
DbName: DefaultDbName,
MetricType: DefaultMetricType,
MetricType: metric.L2,
PrimaryField: DefaultPrimaryFieldName,
VectorField: DefaultVectorFieldName,
EnableDynamicField: EnableDynamic,

View File

@ -271,7 +271,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
name: "get load status fail",
mp: mp2,
exceptCode: http.StatusOK,
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"\",\"shardsNum\":1}}",
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"\",\"shardsNum\":1}}",
})
mp3 := mocks.NewMockProxy(t)
@ -293,7 +293,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
name: "show collection details success",
mp: mp4,
exceptCode: http.StatusOK,
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
})
for _, tt := range testCases {

View File

@ -66,12 +66,12 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.query)))))
router.POST(EntityCategory+GetAction, timeoutMiddleware(wrapperPost(func() any {
return &CollectionIDOutputReq{
return &CollectionIDReq{
OutputFields: []string{DefaultOutputFields},
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.get)))))
router.POST(EntityCategory+DeleteAction, timeoutMiddleware(wrapperPost(func() any {
return &CollectionIDFilterReq{}
return &CollectionFilterReq{}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.delete)))))
router.POST(EntityCategory+InsertAction, timeoutMiddleware(wrapperPost(func() any {
return &CollectionDataReq{}
@ -84,11 +84,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))))
router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any {
router.POST(EntityCategory+AdvancedSearchAction, timeoutMiddleware(wrapperPost(func() any {
return &HybridSearchReq{
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch)))))
}, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))))
router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions)))))
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions)))))
@ -534,7 +534,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
}
func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*CollectionIDOutputReq)
httpReq := anyReq.(*CollectionIDReq)
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
@ -577,7 +577,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
}
func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*CollectionIDFilterReq)
httpReq := anyReq.(*CollectionFilterReq)
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
@ -728,6 +728,43 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
return resp, err
}
func generatePlaceholderGroup(ctx context.Context, body string, collSchema *schemapb.CollectionSchema, fieldName string) ([]byte, error) {
var err error
var vectorField *schemapb.FieldSchema
if len(fieldName) == 0 {
for _, field := range collSchema.Fields {
if IsVectorField(field) {
if len(fieldName) == 0 {
fieldName = field.Name
vectorField = field
} else {
return nil, errors.New("search without annsFields, but already found multiple vector fields: [" + fieldName + ", " + field.Name + "]")
}
}
}
} else {
for _, field := range collSchema.Fields {
if field.Name == fieldName && IsVectorField(field) {
vectorField = field
break
}
}
}
if vectorField == nil {
return nil, errors.New("cannot find a vector field named: " + fieldName)
}
dim, _ := getDim(vectorField)
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
if err != nil {
return nil, err
}
return proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
phv,
},
})
}
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
params := map[string]interface{}{ // auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
@ -759,6 +796,10 @@ func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[str
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
}
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
if err != nil {
return nil, err
@ -766,12 +807,23 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
body, _ := c.Get(gin.BodyBytesKey)
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(),
})
return nil, err
}
req := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
PlaceholderGroup: vectors2PlaceholderGroupBytes(httpReq.Vector),
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
@ -803,14 +855,20 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
return resp, err
}
func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*HybridSearchReq)
req := &milvuspb.HybridSearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Requests: []*milvuspb.SearchRequest{},
}
for _, subReq := range httpReq.Search {
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
if err != nil {
return nil, err
}
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
if err != nil {
return nil, err
@ -820,11 +878,20 @@ func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq an
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(),
})
return nil, err
}
searchReq := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: vectors2PlaceholderGroupBytes(subReq.Vector),
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
@ -871,20 +938,53 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
var err error
fieldNames := map[string]bool{}
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
if httpReq.Dimension == 0 {
err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName",
"dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")")
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
idDataType := schemapb.DataType_Int64
switch httpReq.IDType {
case "Varchar":
idDataType = schemapb.DataType_VarChar
case "":
httpReq.IDType = "Int64"
case "Int64":
default:
err := merr.WrapErrParameterInvalid("Int64, Varchar", httpReq.IDType,
"idType can only be [Int64, Varchar](case sensitive), default: Int64")
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
if len(httpReq.PrimaryFieldName) == 0 {
httpReq.PrimaryFieldName = PrimaryFieldName
}
if len(httpReq.VectorFieldName) == 0 {
httpReq.VectorFieldName = VectorFieldName
}
schema, err = proto.Marshal(&schemapb.CollectionSchema{
Name: httpReq.CollectionName,
AutoID: EnableAutoID,
Fields: []*schemapb.FieldSchema{
{
FieldID: common.StartOfUserFieldID,
Name: PrimaryFieldName,
Name: httpReq.PrimaryFieldName,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
DataType: idDataType,
AutoID: EnableAutoID,
},
{
FieldID: common.StartOfUserFieldID + 1,
Name: VectorFieldName,
Name: httpReq.VectorFieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
@ -967,6 +1067,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
return resp, err
}
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
if len(httpReq.MetricType) == 0 {
httpReq.MetricType = DefaultMetricType
}
createIndexReq := &milvuspb.CreateIndexRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,

View File

@ -400,13 +400,34 @@ func TestDatabaseWrapper(t *testing.T) {
func TestCreateCollection(t *testing.T) {
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(7)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(9)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(CollectionCategory, CreateAction)
// quickly create collection
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`),
errMsg: "dimension is required for quickly create collection(default metric type: COSINE): invalid parameter[expected=collectionName & dimension][actual=collectionName]",
errCode: 1100, // ErrParameterInvalid
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar"}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "unknown"}`),
errMsg: "idType can only be [Int64, Varchar](case sensitive), default: Int64: invalid parameter[expected=Int64, Varchar][actual=unknown]",
errCode: 1100, // ErrParameterInvalid
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricType": "L2"}`),
@ -995,13 +1016,7 @@ func TestDML(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(6)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
ErrorCode: 1700, // ErrFieldNotFound
Reason: "groupBy field not found in schema: field not found[field=test]",
}}, nil).Once()
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice()
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Times(3)
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
@ -1009,51 +1024,19 @@ func TestDML(t *testing.T) {
mp.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
errMsg: "can only accept json format request, error: invalid search params",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "groupBy field not found in schema: field not found[field=test]",
errCode: 65535,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field SearchReqV2.vector of type float32",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: HybridSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: HybridSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: QueryAction,
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]", "outputFields": ["book_id", "word_count", "book_intro"], "offset": 1}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: GetAction,
requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`),
requestBody: []byte(`{"collectionName": "book", "outputFields": ["book_id", "word_count", "book_intro"]}`),
errMsg: "missing required parameters, error: Key: 'CollectionIDReq.ID' Error:Field validation for 'ID' failed on the 'required' tag",
errCode: 1802, // ErrMissingRequiredParameters
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: QueryAction,
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: InsertAction,
@ -1071,9 +1054,19 @@ func TestDML(t *testing.T) {
path: UpsertAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: DeleteAction,
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: DeleteAction,
requestBody: []byte(`{"collectionName": "book", "id" : [0]}`),
errMsg: "missing required parameters, error: Key: 'CollectionFilterReq.Filter' Error:Field validation for 'Filter' failed on the 'required' tag",
errCode: 1802, // ErrMissingRequiredParameters
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: GetAction,
requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: GetAction,
@ -1095,7 +1088,7 @@ func TestDML(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: DeleteAction,
requestBody: []byte(`{"collectionName": "book", "id" : [0]}`),
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`),
errMsg: "",
errCode: 65535,
})
@ -1120,3 +1113,180 @@ func TestDML(t *testing.T) {
})
}
}
func TestSearchV2(t *testing.T) {
paramtable.Init()
mp := mocks.NewMockProxy(t)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: generateCollectionSchema(schemapb.DataType_Int64),
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(9)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
ErrorCode: 1700, // ErrFieldNotFound
Reason: "groupBy field not found in schema: field not found[field=test]",
}}, nil).Once()
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
collSchema := generateCollectionSchema(schemapb.DataType_Int64)
binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector)
binaryVectorField.Name = "binaryVector"
float16VectorField := generateVectorFieldSchema(schemapb.DataType_Float16Vector)
float16VectorField.Name = "float16Vector"
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
bfloat16VectorField.Name = "bfloat16Vector"
collSchema.Fields = append(collSchema.Fields, &binaryVectorField)
collSchema.Fields = append(collSchema.Fields, &float16VectorField)
collSchema.Fields = append(collSchema.Fields, &bfloat16VectorField)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(9)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
errMsg: "can only accept json format request, error: invalid search params",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "groupBy field not found in schema: field not found[field=test]",
errCode: 65535,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go value of type float32: invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
// annsField
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "word_count", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "can only accept json format request, error: cannot find a vector field named: word_count",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
errMsg: "can only accept json format request, error: cannot find a vector field named: float_vector1",
errCode: 1801,
})
// multiple annsFields
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "can only accept json format request, error: search without annsFields, but already found multiple vector fields: [book_intro, binaryVector]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal number 0.1 into Go value of type uint8: invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2, 0.3]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "can only accept json format request, error: dimension: 2, but length of []float: 3: invalid parameter[expected=FloatVector][actual=[0.1, 0.2, 0.3]]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQID"], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "can only accept json format request, error: dimension: 8, bytesLen: 1, but length of []byte: 3: invalid parameter[expected=BinaryVector][actual=\x01\x02\x03]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQID"], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=Float16Vector][actual=\x01\x02\x03]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQID"], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=BFloat16Vector][actual=\x01\x02\x03]",
errCode: 1801,
})
for _, testcase := range queryTestCases {
t.Run("search", func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
} else {
assert.Equal(t, int32(http.StatusOK), returnBody.Code)
}
fmt.Println(w.Body.String())
})
}
}

View File

@ -99,7 +99,7 @@ type QueryReqV2 struct {
func (req *QueryReqV2) GetDbName() string { return req.DbName }
type CollectionIDOutputReq struct {
type CollectionIDReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
@ -108,17 +108,16 @@ type CollectionIDOutputReq struct {
ID interface{} `json:"id" binding:"required"`
}
func (req *CollectionIDOutputReq) GetDbName() string { return req.DbName }
func (req *CollectionIDReq) GetDbName() string { return req.DbName }
type CollectionIDFilterReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
ID interface{} `json:"id"`
Filter string `json:"filter"`
type CollectionFilterReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
}
func (req *CollectionIDFilterReq) GetDbName() string { return req.DbName }
func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
type CollectionDataReq struct {
DbName string `json:"dbName"`
@ -132,7 +131,8 @@ func (req *CollectionDataReq) GetDbName() string { return req.DbName }
type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Vector [][]float32 `json:"vector"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
@ -150,7 +150,7 @@ type Rand struct {
}
type SubSearchReq struct {
Vector [][]float32 `json:"vector"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
@ -296,12 +296,15 @@ type CollectionSchema struct {
}
type CollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"`
MetricType string `json:"metricType"`
Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"`
IDType string `json:"idType"`
MetricType string `json:"metricType"`
PrimaryFieldName string `json:"primaryFieldName"`
VectorFieldName string `json:"vectorFieldName"`
Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"`
}
func (req *CollectionReq) GetDbName() string { return req.DbName }

View File

@ -183,7 +183,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}) {
var reallyDataArray []map[string]interface{}
dataResult := gjson.Get(body, "data")
dataResult := gjson.Get(body, HTTPRequestData)
dataResultArray := dataResult.Array()
if len(dataResultArray) == 0 {
return merr.ErrMissingRequiredParameters, reallyDataArray
@ -914,6 +914,67 @@ func serialize(fv []float32) []byte {
return data
}
func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
for _, vector := range vectors {
var vectorArray []float32
err := json.Unmarshal([]byte(vector.String()), &vectorArray)
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
}
if int64(len(vectorArray)) != dimension {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(),
fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(vectorArray)))
}
vectorBytes := serialize(vectorArray)
values = append(values, vectorBytes)
}
return values, nil
}
func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error())
}
for _, vectorArray := range values {
if int64(len(vectorArray)) != bytesLen {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], string(vectorArray),
fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(vectorArray)))
}
}
return values, nil
}
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
var valueType commonpb.PlaceholderType
var values [][]byte
var err error
switch dataType {
case schemapb.DataType_FloatVector:
valueType = commonpb.PlaceholderType_FloatVector
values, err = serializeFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType, dimension, dimension*4)
case schemapb.DataType_BinaryVector:
valueType = commonpb.PlaceholderType_BinaryVector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8)
case schemapb.DataType_Float16Vector:
valueType = commonpb.PlaceholderType_Float16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
case schemapb.DataType_BFloat16Vector:
valueType = commonpb.PlaceholderType_BFloat16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
}
if err != nil {
return nil, err
}
return &commonpb.PlaceholderValue{
Tag: "$0",
Type: valueType,
Values: values,
}, nil
}
// todo: support [][]byte for BinaryVector
func vectors2PlaceholderGroupBytes(vectors [][]float32) []byte {
var placeHolderType commonpb.PlaceholderType

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
@ -492,6 +493,42 @@ func TestSerialize(t *testing.T) {
parameters := []float32{0.11111, 0.22222}
assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters)))
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo
requestBody := "{\"data\": [[0.11111, 0.22222]]}"
vectors := gjson.Get(requestBody, HTTPRequestData)
values, err := serializeFloatVectors(vectors.Array(), schemapb.DataType_FloatVector, 2, -1)
assert.Nil(t, err)
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
}
bytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
placeholderValue,
},
})
assert.Nil(t, err)
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(bytes)) // todo
for _, dataType := range []schemapb.DataType{schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector} {
request := map[string]interface{}{
HTTPRequestData: []interface{}{
[]byte{1, 2},
},
}
requestBody, _ := json.Marshal(request)
values, err = serializeByteVectors(gjson.Get(string(requestBody), HTTPRequestData).Raw, dataType, -1, 2)
assert.Nil(t, err)
placeholderValue = &commonpb.PlaceholderValue{
Tag: "$0",
Values: values,
}
_, err = proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
placeholderValue,
},
})
assert.Nil(t, err)
}
}
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {