mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 02:48:45 +08:00
fix: [restful v2] search & advanced_search API (#31113)
issue: #30688 former pr: #30946 1. param `vector` is required #31012 2. param `annsField` is optional, for multiply vector fields #31010 3. support BinaryVector, Float16Vector, BFloat16Vector #31013 4. replace vector with data, to align with pymilvus milvus_client #31093 5. create collection quickly, to align with pymilvus milvus_client #31149 --------- Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
parent
9cfe183253
commit
58d7b9f902
@ -1,6 +1,10 @@
|
||||
package httpserver
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
// v2
|
||||
const (
|
||||
@ -14,23 +18,23 @@ const (
|
||||
AliasCategory = "/aliases/"
|
||||
ImportJobCategory = "/jobs/import/"
|
||||
|
||||
ListAction = "list"
|
||||
HasAction = "has"
|
||||
DescribeAction = "describe"
|
||||
CreateAction = "create"
|
||||
DropAction = "drop"
|
||||
StatsAction = "get_stats"
|
||||
LoadStateAction = "get_load_state"
|
||||
RenameAction = "rename"
|
||||
LoadAction = "load"
|
||||
ReleaseAction = "release"
|
||||
QueryAction = "query"
|
||||
GetAction = "get"
|
||||
DeleteAction = "delete"
|
||||
InsertAction = "insert"
|
||||
UpsertAction = "upsert"
|
||||
SearchAction = "search"
|
||||
HybridSearchAction = "hybrid_search"
|
||||
ListAction = "list"
|
||||
HasAction = "has"
|
||||
DescribeAction = "describe"
|
||||
CreateAction = "create"
|
||||
DropAction = "drop"
|
||||
StatsAction = "get_stats"
|
||||
LoadStateAction = "get_load_state"
|
||||
RenameAction = "rename"
|
||||
LoadAction = "load"
|
||||
ReleaseAction = "release"
|
||||
QueryAction = "query"
|
||||
GetAction = "get"
|
||||
DeleteAction = "delete"
|
||||
InsertAction = "insert"
|
||||
UpsertAction = "upsert"
|
||||
SearchAction = "search"
|
||||
AdvancedSearchAction = "advanced_search"
|
||||
|
||||
UpdatePasswordAction = "update_password"
|
||||
GrantRoleAction = "grant_role"
|
||||
@ -72,6 +76,7 @@ const (
|
||||
HTTPIndexName = "indexName"
|
||||
HTTPIndexField = "fieldName"
|
||||
HTTPAliasName = "aliasName"
|
||||
HTTPRequestData = "data"
|
||||
DefaultDbName = "default"
|
||||
DefaultIndexName = "vector_idx"
|
||||
DefaultAliasName = "the_alias"
|
||||
@ -112,7 +117,7 @@ const (
|
||||
HTTPReturnGrantor = "grantor"
|
||||
HTTPReturnDbName = "dbName"
|
||||
|
||||
DefaultMetricType = "L2"
|
||||
DefaultMetricType = metric.COSINE
|
||||
DefaultPrimaryFieldName = "id"
|
||||
DefaultVectorFieldName = "vector"
|
||||
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
)
|
||||
|
||||
@ -220,7 +221,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) {
|
||||
func (h *HandlersV1) createCollection(c *gin.Context) {
|
||||
httpReq := CreateCollectionReq{
|
||||
DbName: DefaultDbName,
|
||||
MetricType: DefaultMetricType,
|
||||
MetricType: metric.L2,
|
||||
PrimaryField: DefaultPrimaryFieldName,
|
||||
VectorField: DefaultVectorFieldName,
|
||||
EnableDynamicField: EnableDynamic,
|
||||
|
@ -271,7 +271,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
|
||||
name: "get load status fail",
|
||||
mp: mp2,
|
||||
exceptCode: http.StatusOK,
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"\",\"shardsNum\":1}}",
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"\",\"shardsNum\":1}}",
|
||||
})
|
||||
|
||||
mp3 := mocks.NewMockProxy(t)
|
||||
@ -293,7 +293,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
|
||||
name: "show collection details success",
|
||||
mp: mp4,
|
||||
exceptCode: http.StatusOK,
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
|
||||
})
|
||||
|
||||
for _, tt := range testCases {
|
||||
|
@ -66,12 +66,12 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.query)))))
|
||||
router.POST(EntityCategory+GetAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
return &CollectionIDOutputReq{
|
||||
return &CollectionIDReq{
|
||||
OutputFields: []string{DefaultOutputFields},
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.get)))))
|
||||
router.POST(EntityCategory+DeleteAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
return &CollectionIDFilterReq{}
|
||||
return &CollectionFilterReq{}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.delete)))))
|
||||
router.POST(EntityCategory+InsertAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
return &CollectionDataReq{}
|
||||
@ -84,11 +84,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
|
||||
Limit: 100,
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))))
|
||||
router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
router.POST(EntityCategory+AdvancedSearchAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
return &HybridSearchReq{
|
||||
Limit: 100,
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch)))))
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))))
|
||||
|
||||
router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions)))))
|
||||
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions)))))
|
||||
@ -534,7 +534,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
|
||||
}
|
||||
|
||||
func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*CollectionIDOutputReq)
|
||||
httpReq := anyReq.(*CollectionIDReq)
|
||||
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -577,7 +577,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
|
||||
}
|
||||
|
||||
func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*CollectionIDFilterReq)
|
||||
httpReq := anyReq.(*CollectionFilterReq)
|
||||
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -728,6 +728,43 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func generatePlaceholderGroup(ctx context.Context, body string, collSchema *schemapb.CollectionSchema, fieldName string) ([]byte, error) {
|
||||
var err error
|
||||
var vectorField *schemapb.FieldSchema
|
||||
if len(fieldName) == 0 {
|
||||
for _, field := range collSchema.Fields {
|
||||
if IsVectorField(field) {
|
||||
if len(fieldName) == 0 {
|
||||
fieldName = field.Name
|
||||
vectorField = field
|
||||
} else {
|
||||
return nil, errors.New("search without annsFields, but already found multiple vector fields: [" + fieldName + ", " + field.Name + "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.Name == fieldName && IsVectorField(field) {
|
||||
vectorField = field
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if vectorField == nil {
|
||||
return nil, errors.New("cannot find a vector field named: " + fieldName)
|
||||
}
|
||||
dim, _ := getDim(vectorField)
|
||||
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proto.Marshal(&commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
phv,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
|
||||
params := map[string]interface{}{ // auto generated mapping
|
||||
"level": int(commonpb.ConsistencyLevel_Bounded),
|
||||
@ -759,6 +796,10 @@ func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[str
|
||||
|
||||
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*SearchReqV2)
|
||||
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -766,12 +807,23 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
|
||||
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
req := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
PlaceholderGroup: vectors2PlaceholderGroupBytes(httpReq.Vector),
|
||||
PlaceholderGroup: placeholderGroup,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
@ -803,14 +855,20 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*HybridSearchReq)
|
||||
req := &milvuspb.HybridSearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Requests: []*milvuspb.SearchRequest{},
|
||||
}
|
||||
for _, subReq := range httpReq.Search {
|
||||
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
|
||||
for i, subReq := range httpReq.Search {
|
||||
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -820,11 +878,20 @@ func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq an
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
|
||||
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
|
||||
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
searchReq := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: subReq.Filter,
|
||||
PlaceholderGroup: vectors2PlaceholderGroupBytes(subReq.Vector),
|
||||
PlaceholderGroup: placeholderGroup,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
@ -871,20 +938,53 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
var err error
|
||||
fieldNames := map[string]bool{}
|
||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||
if httpReq.Dimension == 0 {
|
||||
err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName",
|
||||
"dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")")
|
||||
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
idDataType := schemapb.DataType_Int64
|
||||
switch httpReq.IDType {
|
||||
case "Varchar":
|
||||
idDataType = schemapb.DataType_VarChar
|
||||
case "":
|
||||
httpReq.IDType = "Int64"
|
||||
case "Int64":
|
||||
default:
|
||||
err := merr.WrapErrParameterInvalid("Int64, Varchar", httpReq.IDType,
|
||||
"idType can only be [Int64, Varchar](case sensitive), default: Int64")
|
||||
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if len(httpReq.PrimaryFieldName) == 0 {
|
||||
httpReq.PrimaryFieldName = PrimaryFieldName
|
||||
}
|
||||
if len(httpReq.VectorFieldName) == 0 {
|
||||
httpReq.VectorFieldName = VectorFieldName
|
||||
}
|
||||
schema, err = proto.Marshal(&schemapb.CollectionSchema{
|
||||
Name: httpReq.CollectionName,
|
||||
AutoID: EnableAutoID,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.StartOfUserFieldID,
|
||||
Name: PrimaryFieldName,
|
||||
Name: httpReq.PrimaryFieldName,
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
DataType: idDataType,
|
||||
AutoID: EnableAutoID,
|
||||
},
|
||||
{
|
||||
FieldID: common.StartOfUserFieldID + 1,
|
||||
Name: VectorFieldName,
|
||||
Name: httpReq.VectorFieldName,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
@ -967,6 +1067,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
||||
return resp, err
|
||||
}
|
||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.MetricType) == 0 {
|
||||
httpReq.MetricType = DefaultMetricType
|
||||
}
|
||||
createIndexReq := &milvuspb.CreateIndexRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
|
@ -400,13 +400,34 @@ func TestDatabaseWrapper(t *testing.T) {
|
||||
func TestCreateCollection(t *testing.T) {
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(7)
|
||||
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
|
||||
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(9)
|
||||
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
|
||||
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
|
||||
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
|
||||
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
path := versionalV2(CollectionCategory, CreateAction)
|
||||
// quickly create collection
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`),
|
||||
errMsg: "dimension is required for quickly create collection(default metric type: COSINE): invalid parameter[expected=collectionName & dimension][actual=collectionName]",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar"}`),
|
||||
})
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "unknown"}`),
|
||||
errMsg: "idType can only be [Int64, Varchar](case sensitive), default: Int64: invalid parameter[expected=Int64, Varchar][actual=unknown]",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2}`),
|
||||
})
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricType": "L2"}`),
|
||||
@ -995,13 +1016,7 @@ func TestDML(t *testing.T) {
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Times(6)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
|
||||
ErrorCode: 1700, // ErrFieldNotFound
|
||||
Reason: "groupBy field not found in schema: field not found[field=test]",
|
||||
}}, nil).Once()
|
||||
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
|
||||
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice()
|
||||
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Times(3)
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once()
|
||||
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
|
||||
@ -1009,51 +1024,19 @@ func TestDML(t *testing.T) {
|
||||
mp.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus}, nil).Once()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
queryTestCases := []requestBodyTestCase{}
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
|
||||
errMsg: "can only accept json format request, error: invalid search params",
|
||||
errCode: 1801, // ErrIncorrectParameterFormat
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "groupBy field not found in schema: field not found[field=test]",
|
||||
errCode: 65535,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field SearchReqV2.vector of type float32",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: HybridSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: HybridSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: QueryAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]", "outputFields": ["book_id", "word_count", "book_intro"], "offset": 1}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: GetAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`),
|
||||
requestBody: []byte(`{"collectionName": "book", "outputFields": ["book_id", "word_count", "book_intro"]}`),
|
||||
errMsg: "missing required parameters, error: Key: 'CollectionIDReq.ID' Error:Field validation for 'ID' failed on the 'required' tag",
|
||||
errCode: 1802, // ErrMissingRequiredParameters
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: QueryAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: InsertAction,
|
||||
@ -1071,9 +1054,19 @@ func TestDML(t *testing.T) {
|
||||
path: UpsertAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: DeleteAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: DeleteAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "id" : [0]}`),
|
||||
errMsg: "missing required parameters, error: Key: 'CollectionFilterReq.Filter' Error:Field validation for 'Filter' failed on the 'required' tag",
|
||||
errCode: 1802, // ErrMissingRequiredParameters
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: GetAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: GetAction,
|
||||
@ -1095,7 +1088,7 @@ func TestDML(t *testing.T) {
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: DeleteAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "id" : [0]}`),
|
||||
requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [0]"}`),
|
||||
errMsg: "",
|
||||
errCode: 65535,
|
||||
})
|
||||
@ -1120,3 +1113,180 @@ func TestDML(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchV2(t *testing.T) {
|
||||
paramtable.Init()
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: generateCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Times(9)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
|
||||
ErrorCode: 1700, // ErrFieldNotFound
|
||||
Reason: "groupBy field not found in schema: field not found[field=test]",
|
||||
}}, nil).Once()
|
||||
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
|
||||
collSchema := generateCollectionSchema(schemapb.DataType_Int64)
|
||||
binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector)
|
||||
binaryVectorField.Name = "binaryVector"
|
||||
float16VectorField := generateVectorFieldSchema(schemapb.DataType_Float16Vector)
|
||||
float16VectorField.Name = "float16Vector"
|
||||
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
|
||||
bfloat16VectorField.Name = "bfloat16Vector"
|
||||
collSchema.Fields = append(collSchema.Fields, &binaryVectorField)
|
||||
collSchema.Fields = append(collSchema.Fields, &float16VectorField)
|
||||
collSchema.Fields = append(collSchema.Fields, &bfloat16VectorField)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
CollectionName: DefaultCollectionName,
|
||||
Schema: collSchema,
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Times(9)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
queryTestCases := []requestBodyTestCase{}
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
|
||||
errMsg: "can only accept json format request, error: invalid search params",
|
||||
errCode: 1801, // ErrIncorrectParameterFormat
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "groupBy field not found in schema: field not found[field=test]",
|
||||
errCode: 65535,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [["0.1", "0.2"]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go value of type float32: invalid parameter[expected=FloatVector][actual=[\"0.1\", \"0.2\"]]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
})
|
||||
// annsField
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "word_count", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "can only accept json format request, error: cannot find a vector field named: word_count",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"data": [[0.1, 0.2]], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"data": [[0.1, 0.2]], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
|
||||
errMsg: "can only accept json format request, error: cannot find a vector field named: float_vector1",
|
||||
errCode: 1801,
|
||||
})
|
||||
// multiple annsFields
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
errMsg: "can only accept json format request, error: search without annsFields, but already found multiple vector fields: [book_intro, binaryVector]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
errMsg: "can only accept json format request, error: json: cannot unmarshal number 0.1 into Go value of type uint8: invalid parameter[expected=BinaryVector][actual=[[0.1, 0.2]]]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2, 0.3]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
errMsg: "can only accept json format request, error: dimension: 2, but length of []float: 3: invalid parameter[expected=FloatVector][actual=[0.1, 0.2, 0.3]]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQID"], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
errMsg: "can only accept json format request, error: dimension: 8, bytesLen: 1, but length of []byte: 3: invalid parameter[expected=BinaryVector][actual=\x01\x02\x03]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQID"], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=Float16Vector][actual=\x01\x02\x03]",
|
||||
errCode: 1801,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQID"], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=BFloat16Vector][actual=\x01\x02\x03]",
|
||||
errCode: 1801,
|
||||
})
|
||||
|
||||
for _, testcase := range queryTestCases {
|
||||
t.Run("search", func(t *testing.T) {
|
||||
bodyReader := bytes.NewReader(testcase.requestBody)
|
||||
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
returnBody := &ReturnErrMsg{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), returnBody)
|
||||
assert.Nil(t, err)
|
||||
if testcase.errCode != 0 {
|
||||
assert.Equal(t, testcase.errCode, returnBody.Code)
|
||||
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
||||
} else {
|
||||
assert.Equal(t, int32(http.StatusOK), returnBody.Code)
|
||||
}
|
||||
fmt.Println(w.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ type QueryReqV2 struct {
|
||||
|
||||
func (req *QueryReqV2) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionIDOutputReq struct {
|
||||
type CollectionIDReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
@ -108,17 +108,16 @@ type CollectionIDOutputReq struct {
|
||||
ID interface{} `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *CollectionIDOutputReq) GetDbName() string { return req.DbName }
|
||||
func (req *CollectionIDReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionIDFilterReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
ID interface{} `json:"id"`
|
||||
Filter string `json:"filter"`
|
||||
type CollectionFilterReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
Filter string `json:"filter" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *CollectionIDFilterReq) GetDbName() string { return req.DbName }
|
||||
func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionDataReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
@ -132,7 +131,8 @@ func (req *CollectionDataReq) GetDbName() string { return req.DbName }
|
||||
type SearchReqV2 struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Vector [][]float32 `json:"vector"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
@ -150,7 +150,7 @@ type Rand struct {
|
||||
}
|
||||
|
||||
type SubSearchReq struct {
|
||||
Vector [][]float32 `json:"vector"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
@ -296,12 +296,15 @@ type CollectionSchema struct {
|
||||
}
|
||||
|
||||
type CollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Dimension int32 `json:"dimension"`
|
||||
MetricType string `json:"metricType"`
|
||||
Schema CollectionSchema `json:"schema"`
|
||||
IndexParams []IndexParam `json:"indexParams"`
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Dimension int32 `json:"dimension"`
|
||||
IDType string `json:"idType"`
|
||||
MetricType string `json:"metricType"`
|
||||
PrimaryFieldName string `json:"primaryFieldName"`
|
||||
VectorFieldName string `json:"vectorFieldName"`
|
||||
Schema CollectionSchema `json:"schema"`
|
||||
IndexParams []IndexParam `json:"indexParams"`
|
||||
}
|
||||
|
||||
func (req *CollectionReq) GetDbName() string { return req.DbName }
|
||||
|
@ -183,7 +183,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
|
||||
|
||||
func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}) {
|
||||
var reallyDataArray []map[string]interface{}
|
||||
dataResult := gjson.Get(body, "data")
|
||||
dataResult := gjson.Get(body, HTTPRequestData)
|
||||
dataResultArray := dataResult.Array()
|
||||
if len(dataResultArray) == 0 {
|
||||
return merr.ErrMissingRequiredParameters, reallyDataArray
|
||||
@ -914,6 +914,67 @@ func serialize(fv []float32) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
|
||||
values := make([][]byte, 0)
|
||||
for _, vector := range vectors {
|
||||
var vectorArray []float32
|
||||
err := json.Unmarshal([]byte(vector.String()), &vectorArray)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
|
||||
}
|
||||
if int64(len(vectorArray)) != dimension {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(),
|
||||
fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(vectorArray)))
|
||||
}
|
||||
vectorBytes := serialize(vectorArray)
|
||||
values = append(values, vectorBytes)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
|
||||
values := make([][]byte, 0)
|
||||
err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error())
|
||||
}
|
||||
for _, vectorArray := range values {
|
||||
if int64(len(vectorArray)) != bytesLen {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], string(vectorArray),
|
||||
fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(vectorArray)))
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
||||
var valueType commonpb.PlaceholderType
|
||||
var values [][]byte
|
||||
var err error
|
||||
switch dataType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
valueType = commonpb.PlaceholderType_FloatVector
|
||||
values, err = serializeFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType, dimension, dimension*4)
|
||||
case schemapb.DataType_BinaryVector:
|
||||
valueType = commonpb.PlaceholderType_BinaryVector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8)
|
||||
case schemapb.DataType_Float16Vector:
|
||||
valueType = commonpb.PlaceholderType_Float16Vector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
valueType = commonpb.PlaceholderType_BFloat16Vector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: valueType,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// todo: support [][]byte for BinaryVector
|
||||
func vectors2PlaceholderGroupBytes(vectors [][]float32) []byte {
|
||||
var placeHolderType commonpb.PlaceholderType
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
@ -492,6 +493,42 @@ func TestSerialize(t *testing.T) {
|
||||
parameters := []float32{0.11111, 0.22222}
|
||||
assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters)))
|
||||
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo
|
||||
requestBody := "{\"data\": [[0.11111, 0.22222]]}"
|
||||
vectors := gjson.Get(requestBody, HTTPRequestData)
|
||||
values, err := serializeFloatVectors(vectors.Array(), schemapb.DataType_FloatVector, 2, -1)
|
||||
assert.Nil(t, err)
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: values,
|
||||
}
|
||||
bytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
placeholderValue,
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(bytes)) // todo
|
||||
for _, dataType := range []schemapb.DataType{schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector} {
|
||||
request := map[string]interface{}{
|
||||
HTTPRequestData: []interface{}{
|
||||
[]byte{1, 2},
|
||||
},
|
||||
}
|
||||
requestBody, _ := json.Marshal(request)
|
||||
values, err = serializeByteVectors(gjson.Get(string(requestBody), HTTPRequestData).Raw, dataType, -1, 2)
|
||||
assert.Nil(t, err)
|
||||
placeholderValue = &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Values: values,
|
||||
}
|
||||
_, err = proto.Marshal(&commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
placeholderValue,
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {
|
||||
|
Loading…
Reference in New Issue
Block a user