fix: restful v2 (#32144)

issue: #31176

1. cannot get dbName correctly while describe alias #31978
2. return a valid json string even if the user doesn't have the whole
privileges to describe collection #31635
3. rename IndexParam.IndexConfig to IndexParam.Params
4. FieldSchema.ElementTypeParams, IndexParam.Params can not only accept
string

Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
PowderLi 2024-04-13 21:55:29 +08:00 committed by GitHub
parent ab6ddf6929
commit 610a65af14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 124 additions and 51 deletions

View File

@ -3,6 +3,7 @@ package httpserver
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
@ -10,7 +11,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
validator "github.com/go-playground/validator/v10"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/tidwall/gjson"
@ -222,13 +223,32 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 {
}
}
func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, req interface{}) error {
username, ok := c.Get(ContextUsername)
if !ok || username.(string) == "" {
if !ignoreErr {
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
}
return merr.ErrNeedAuthenticate
}
_, authErr := proxy.PrivilegeInterceptor(ctx, req)
if authErr != nil {
if !ignoreErr {
c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()})
}
return authErr
}
return nil
}
func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
if baseGetter, ok := req.(BaseGetter); ok {
span := trace.SpanFromContext(ctx)
span.AddEvent(baseGetter.GetBase().GetMsgType().String())
}
if checkAuth {
err := checkAuthorization(ctx, c, req)
err := checkAuthorizationV2(ctx, c, ignoreErr, req)
if err != nil {
return nil, err
}
@ -331,6 +351,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
} else {
autoID = primaryField.AutoID
}
errMessage := ""
loadStateReq := &milvuspb.GetLoadStateRequest{
DbName: dbName,
CollectionName: collectionName,
@ -341,6 +362,8 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
collLoadState := ""
if err == nil {
collLoadState = stateResp.(*milvuspb.GetLoadStateResponse).State.String()
} else {
errMessage += err.Error() + ";"
}
vectorField := ""
for _, field := range coll.Schema.Fields {
@ -355,22 +378,26 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
CollectionName: collectionName,
FieldName: vectorField,
}
indexResp, err := wrapperProxy(ctx, c, descIndexReq, false, true, func(reqCtx context.Context, req any) (any, error) {
indexResp, err := wrapperProxy(ctx, c, descIndexReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) {
return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
})
if err == nil {
indexDesc = printIndexes(indexResp.(*milvuspb.DescribeIndexResponse).IndexDescriptions)
} else {
errMessage += err.Error() + ";"
}
var aliases []string
aliasReq := &milvuspb.ListAliasesRequest{
DbName: dbName,
CollectionName: collectionName,
}
aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, true, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest))
})
if err == nil {
aliases = aliasResp.(*milvuspb.ListAliasesResponse).GetAliases()
} else {
errMessage += err.Error() + "."
}
if aliases == nil {
aliases = []string{}
@ -392,7 +419,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
"consistencyLevel": commonpb.ConsistencyLevel_name[int32(coll.ConsistencyLevel)],
"enableDynamicField": coll.Schema.EnableDynamicField,
"properties": coll.Properties,
}})
}, HTTPReturnMessage: errMessage})
return resp, nil
}
@ -443,8 +470,11 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context,
return h.proxy.GetLoadingProgress(reqCtx, req.(*milvuspb.GetLoadingProgressRequest))
})
progress := int64(-1)
errMessage := ""
if err == nil {
progress = progressResp.(*milvuspb.GetLoadingProgressResponse).Progress
} else {
errMessage += err.Error() + "."
}
state := commonpb.LoadState_LoadStateLoading.String()
if progress >= 100 {
@ -453,7 +483,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context,
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{
HTTPReturnLoadState: state,
HTTPReturnLoadProgress: progress,
}})
}, HTTPReturnMessage: errMessage})
return resp, err
}
@ -1081,7 +1111,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
}
}
for key, fieldParam := range field.ElementTypeParams {
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fieldParam})
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", fieldParam)})
}
collSchema.Fields = append(collSchema.Fields, &fieldSchema)
fieldNames[field.FieldName] = true
@ -1175,8 +1205,8 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
IndexName: indexParam.IndexName,
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: indexParam.MetricType}},
}
for key, value := range indexParam.IndexConfig {
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: value})
for key, value := range indexParam.Params {
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
@ -1603,8 +1633,8 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
{Key: common.MetricTypeKey, Value: indexParam.MetricType},
},
}
for key, value := range indexParam.IndexConfig {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: value})
for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
@ -1728,7 +1758,7 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a
CollectionName: collectionName,
}
if h.checkAuth {
err := checkAuthorization(ctx, c, &milvuspb.ListImportsAuthPlaceholder{
err := checkAuthorizationV2(ctx, c, false, &milvuspb.ListImportsAuthPlaceholder{
DbName: dbName,
CollectionName: collectionName,
})
@ -1778,7 +1808,7 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq
Options: funcutil.Map2KeyValuePair(optionsGetter.GetOptions()),
}
if h.checkAuth {
err := checkAuthorization(ctx, c, &milvuspb.ImportAuthPlaceholder{
err := checkAuthorizationV2(ctx, c, false, &milvuspb.ImportAuthPlaceholder{
DbName: dbName,
CollectionName: collectionGetter.GetCollectionName(),
PartitionName: partitionGetter.GetPartitionName(),
@ -1805,7 +1835,7 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an
JobID: jobIDGetter.GetJobID(),
}
if h.checkAuth {
err := checkAuthorization(ctx, c, &milvuspb.GetImportProgressAuthPlaceholder{
err := checkAuthorizationV2(ctx, c, false, &milvuspb.GetImportProgressAuthPlaceholder{
DbName: dbName,
})
if err != nil {

View File

@ -288,6 +288,39 @@ func TestGrpcWrapper(t *testing.T) {
fmt.Println(w.Body.String())
})
}
path = "/wrapper/grpc/auth"
app.GET(path, func(c *gin.Context) {
wrapperProxy(context.Background(), c, &milvuspb.DescribeCollectionRequest{}, true, false, handle)
})
appNeedAuth.GET(path, func(c *gin.Context) {
ctx := proxy.NewContextWithMetadata(c, "test", DefaultDbName)
wrapperProxy(ctx, c, &milvuspb.LoadCollectionRequest{}, true, false, handle)
})
t.Run("check authorization", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, path, nil)
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, int32(1800), returnBody.Code)
assert.Equal(t, "user hasn't authenticated", returnBody.Message)
fmt.Println(w.Body.String())
paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "true")
req = httptest.NewRequest(http.MethodGet, needAuthPrefix+path, nil)
req.SetBasicAuth("test", util.DefaultRootPassword)
w = httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
err = json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, int32(2), returnBody.Code)
assert.Equal(t, "service unavailable: internal: Milvus Proxy is not ready yet. please wait", returnBody.Message)
fmt.Println(w.Body.String())
})
}
type headerTestCase struct {
@ -446,8 +479,8 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "isPartitionKey": false, "elementTypeParams": {}},
{"fieldName": "partition_field", "dataType": "VarChar", "isPartitionKey": true, "elementTypeParams": {"max_length": "256"}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "partition_field", "dataType": "VarChar", "isPartitionKey": true, "elementTypeParams": {"max_length": 256}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}, "params": {"partitionsNum": "32"}}`),
})
@ -457,7 +490,7 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
})
@ -467,7 +500,7 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}}`),
errMsg: "invalid parameter, data type int64 is invalid(case sensitive).",
@ -478,8 +511,8 @@ func TestCreateCollection(t *testing.T) {
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Array", "elementDataType": "Int64", "elementTypeParams": {"max_capacity": "2"}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "word_count", "dataType": "Array", "elementDataType": "Int64", "elementTypeParams": {"max_capacity": 2}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}}`),
})
@ -489,7 +522,7 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Array", "elementDataType": "int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}}`),
errMsg: "invalid parameter, element data type int64 is invalid(case sensitive).",
@ -501,7 +534,7 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_xxx", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
errMsg: "missing required parameters, error: `book_xxx` hasn't defined in schema",
@ -519,7 +552,7 @@ func TestCreateCollection(t *testing.T) {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
errMsg: "",
@ -634,9 +667,11 @@ func TestMethodGet(t *testing.T) {
Schema: generateCollectionSchema(schemapb.DataType_Int64),
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Once()
}, nil).Twice()
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Once()
mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Twice()
mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadStateResponse{Status: commonErrorStatus}, nil).Once()
mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Times(3)
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{Status: commonErrorStatus}, nil).Once()
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&DefaultDescIndexesReqp, nil).Times(3)
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrIndexNotFoundForCollection(DefaultCollectionName)).Once()
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{
@ -658,6 +693,7 @@ func TestMethodGet(t *testing.T) {
Status: commonSuccessStatus,
Progress: int64(77),
}, nil).Once()
mp.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{Status: commonErrorStatus}, nil).Once()
mp.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: &StatusSuccess,
PartitionNames: []string{DefaultPartitionName},
@ -705,6 +741,7 @@ func TestMethodGet(t *testing.T) {
},
},
}, nil).Once()
mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{Status: commonErrorStatus}, nil).Once()
mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{
Status: &StatusSuccess,
}, nil).Once()
@ -736,6 +773,9 @@ func TestMethodGet(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(CollectionCategory, DescribeAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(CollectionCategory, DescribeAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(CollectionCategory, DescribeAction),
errMsg: "",
@ -750,6 +790,9 @@ func TestMethodGet(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(CollectionCategory, LoadStateAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(CollectionCategory, LoadStateAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PartitionCategory, ListAction),
})
@ -993,8 +1036,8 @@ func TestMethodPost(t *testing.T) {
bodyReader := bytes.NewReader([]byte(`{` +
`"collectionName": "` + DefaultCollectionName + `", "newCollectionName": "test", "newDbName": "",` +
`"partitionName": "` + DefaultPartitionName + `", "partitionNames": ["` + DefaultPartitionName + `"],` +
`"schema": {"fields": [{"fieldName": "book_id", "dataType": "Int64", "elementTypeParams": {}}, {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}]},` +
`"indexParams": [{"indexName": "` + DefaultIndexName + `", "fieldName": "book_intro", "metricType": "L2", "indexConfig": {"nlist": "30", "index_type": "IVF_FLAT"}}],` +
`"schema": {"fields": [{"fieldName": "book_id", "dataType": "Int64", "elementTypeParams": {}}, {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}]},` +
`"indexParams": [{"indexName": "` + DefaultIndexName + `", "fieldName": "book_intro", "metricType": "L2", "params": {"nlist": 30, "index_type": "IVF_FLAT"}}],` +
`"userName": "` + util.UserRoot + `", "password": "Milvus", "newPassword": "milvus", "roleName": "` + util.RoleAdmin + `",` +
`"roleName": "` + util.RoleAdmin + `", "objectType": "Global", "objectName": "*", "privilege": "*",` +
`"aliasName": "` + DefaultAliasName + `",` +

View File

@ -264,10 +264,10 @@ type GrantReq struct {
}
type IndexParam struct {
FieldName string `json:"fieldName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
MetricType string `json:"metricType" binding:"required"`
IndexConfig map[string]string `json:"indexConfig"`
FieldName string `json:"fieldName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
MetricType string `json:"metricType" binding:"required"`
Params map[string]interface{} `json:"params"`
}
type IndexParamReq struct {
@ -294,12 +294,12 @@ func (req *IndexReq) GetIndexName() string {
}
type FieldSchema struct {
FieldName string `json:"fieldName" binding:"required"`
DataType string `json:"dataType" binding:"required"`
ElementDataType string `json:"elementDataType"`
IsPrimary bool `json:"isPrimary"`
IsPartitionKey bool `json:"isPartitionKey"`
ElementTypeParams map[string]string `json:"elementTypeParams" binding:"required"`
FieldName string `json:"fieldName" binding:"required"`
DataType string `json:"dataType" binding:"required"`
ElementDataType string `json:"elementDataType"`
IsPrimary bool `json:"isPrimary"`
IsPartitionKey bool `json:"isPartitionKey"`
ElementTypeParams map[string]interface{} `json:"elementTypeParams" binding:"required"`
}
type CollectionSchema struct {

View File

@ -72,9 +72,9 @@ class TestCreateIndex(TestBase):
"metricType": f"{metric_type}"}]
}
if index_type == "HNSW":
payload["indexParams"][0]["indexConfig"] = {"index_type": "HNSW", "M": "16", "efConstruction": "200"}
payload["indexParams"][0]["params"] = {"index_type": "HNSW", "M": "16", "efConstruction": "200"}
if index_type == "AUTOINDEX":
payload["indexParams"][0]["indexConfig"] = {"index_type": "AUTOINDEX"}
payload["indexParams"][0]["params"] = {"index_type": "AUTOINDEX"}
rsp = self.index_client.index_create(payload)
assert rsp['code'] == 200
time.sleep(10)
@ -90,7 +90,7 @@ class TestCreateIndex(TestBase):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['metricType'] == actual_index[i]['metricType']
assert expected_index[i]["indexConfig"]['index_type'] == actual_index[i]['indexType']
assert expected_index[i]["params"]['index_type'] == actual_index[i]['indexType']
# drop index
for i in range(len(actual_index)):
@ -153,7 +153,7 @@ class TestCreateIndex(TestBase):
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector",
"indexConfig": {"index_type": "INVERTED"}}]
"params": {"index_type": "INVERTED"}}]
}
rsp = self.index_client.index_create(payload)
assert rsp['code'] == 200
@ -169,7 +169,7 @@ class TestCreateIndex(TestBase):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['indexConfig']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
@pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"])
@pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"])
@ -221,10 +221,10 @@ class TestCreateIndex(TestBase):
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,
"indexConfig": {"index_type": index_type}}]
"params": {"index_type": index_type}}]
}
if index_type == "BIN_IVF_FLAT":
payload["indexParams"][0]["indexConfig"]["nlist"] = "16384"
payload["indexParams"][0]["params"]["nlist"] = "16384"
rsp = self.index_client.index_create(payload)
assert rsp['code'] == 200
time.sleep(10)
@ -239,7 +239,7 @@ class TestCreateIndex(TestBase):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['indexConfig']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
@pytest.mark.L1
@ -292,10 +292,10 @@ class TestCreateIndexNegative(TestBase):
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,
"indexConfig": {"index_type": index_type}}]
"params": {"index_type": index_type}}]
}
if index_type == "BIN_IVF_FLAT":
payload["indexParams"][0]["indexConfig"]["nlist"] = "16384"
payload["indexParams"][0]["params"]["nlist"] = "16384"
rsp = self.index_client.index_create(payload)
assert rsp['code'] == 1100
assert "not supported" in rsp['message']

View File

@ -183,7 +183,7 @@ class TestInsertVector(TestBase):
{"fieldName": "float16_vector", "indexName": "float16_vector", "metricType": "L2"},
{"fieldName": "bfloat16_vector", "indexName": "bfloat16_vector", "metricType": "L2"},
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
"indexConfig": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
]
}
rsp = self.collection_client.collection_create(payload)
@ -572,7 +572,7 @@ class TestSearchVector(TestBase):
{"fieldName": "float16_vector", "indexName": "float16_vector", "metricType": "COSINE"},
{"fieldName": "bfloat16_vector", "indexName": "bfloat16_vector", "metricType": "COSINE"},
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
"indexConfig": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
]
}
rsp = self.collection_client.collection_create(payload)
@ -751,7 +751,7 @@ class TestSearchVector(TestBase):
},
"indexParams": [
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
"indexConfig": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
]
}
rsp = self.collection_client.collection_create(payload)
@ -1511,7 +1511,7 @@ class TestQueryVector(TestBase):
{"fieldName": "float16_vector", "indexName": "float16_vector", "metricType": "L2"},
{"fieldName": "bfloat16_vector", "indexName": "bfloat16_vector", "metricType": "L2"},
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
"indexConfig": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
]
}
rsp = self.collection_client.collection_create(payload)