diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 93e56393aa..4653b9231d 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -54,5 +54,7 @@ const ( ParamRoundDecimal = "round_decimal" ParamOffset = "offset" ParamLimit = "limit" + ParamRadius = "radius" + ParamRangeFilter = "range_filter" BoundedTimestamp = 2 ) diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 033cfb70ae..1bbb4cfe48 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -862,6 +862,24 @@ func (h *Handlers) search(c *gin.Context) { params := map[string]interface{}{ // auto generated mapping "level": int(commonpb.ConsistencyLevel_Bounded), } + if httpReq.Params != nil { + radius, radiusOk := httpReq.Params[ParamRadius] + rangeFilter, rangeFilterOk := httpReq.Params[ParamRangeFilter] + if rangeFilterOk { + if !radiusOk { + log.Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", + }) + return + } + params[ParamRangeFilter] = rangeFilter + } + if radiusOk { + params[ParamRadius] = radius + } + } bs, _ := json.Marshal(params) searchParams := []*commonpb.KeyValuePair{ {Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index 85eaabe1f3..1795be20dc 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -1294,6 +1294,38 @@ func TestSearch(t *testing.T) { } }) } + mp := mocks.NewMockProxy(t) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{ + Status: &StatusSuccess, + Results: &schemapb.SearchResultData{ + FieldsData: generateFieldData(), + Scores: []float32{0.01, 0.04, 0.09}, + TopK: 3, + }, + }, nil).Once() + tt := testCase{ + name: "search success with params", + mp: mp, + exceptCode: 200, + } + t.Run(tt.name, func(t *testing.T) { + testEngine := initHTTPServer(tt.mp, true) + rows := []float32{0.0, 0.0} + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + "vector": rows, + Params: map[string]float64{ + ParamRadius: 0.9, + ParamRangeFilter: 0.1, + }, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorSearchPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, tt.exceptCode, w.Code) + }) } type ReturnType int @@ -1405,12 +1437,14 @@ func TestHttpRequestFormat(t *testing.T) { merr.ErrMissingRequiredParameters, merr.ErrMissingRequiredParameters, merr.ErrMissingRequiredParameters, + merr.ErrIncorrectParameterFormat, } requestJsons := [][]byte{ []byte(`{"collectionName": {"` + DefaultCollectionName + `", "dimension": 2}`), []byte(`{"collName": "` + DefaultCollectionName + `", "dimension": 2}`), []byte(`{"collName": "` + DefaultCollectionName + `", "dim": 2}`), []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + []byte(`{"collectionName": "` + DefaultCollectionName + `", "vector": [0.0, 0.0], "` + Params + `": {"` + ParamRangeFilter + `": 0.1}}`), } paths := [][]string{ { @@ -1439,6 +1473,8 @@ func TestHttpRequestFormat(t *testing.T) { versional(VectorInsertPath), versional(VectorUpsertPath), versional(VectorDeletePath), + }, { + versional(VectorSearchPath), }, } for i, pathArr := range paths { diff --git a/internal/distributed/proxy/httpserver/request.go b/internal/distributed/proxy/httpserver/request.go index 0ffded9104..228fe3be1d 100644 --- a/internal/distributed/proxy/httpserver/request.go +++ b/internal/distributed/proxy/httpserver/request.go @@ -63,11 +63,12 @@ type SingleUpsertReq struct { } type SearchReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" validate:"required"` - Filter string `json:"filter"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - Vector []float32 `json:"vector"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" validate:"required"` + Filter string `json:"filter"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + Vector []float32 `json:"vector"` + Params map[string]float64 `json:"params"` }