mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
Add index params check in proxy (#5958)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
2926a78968
commit
fd57554d32
@ -26,6 +26,8 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
@ -3121,6 +3123,42 @@ func (cit *CreateIndexTask) PreExecute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// check index param, not accurate, only some static rules
|
||||
indexParams := make(map[string]string)
|
||||
for _, kv := range cit.CreateIndexRequest.ExtraParams {
|
||||
if kv.Key == "params" { // TODO(dragondriver): change `params` to const variable
|
||||
params, err := funcutil.ParseIndexParamsMap(kv.Value)
|
||||
if err != nil {
|
||||
log.Warn("Failed to parse index params",
|
||||
zap.String("params", kv.Value),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
for k, v := range params {
|
||||
indexParams[k] = v
|
||||
}
|
||||
} else {
|
||||
indexParams[kv.Key] = kv.Value
|
||||
}
|
||||
}
|
||||
|
||||
indexType, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable
|
||||
if !exist {
|
||||
indexType = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
|
||||
}
|
||||
|
||||
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get conf adapter", zap.String("index_type", indexType))
|
||||
return fmt.Errorf("invalid index type: %s", indexType)
|
||||
}
|
||||
|
||||
ok := adapter.CheckTrain(indexParams)
|
||||
if !ok {
|
||||
log.Warn("Create index with invalid params", zap.Any("index_params", indexParams))
|
||||
return fmt.Errorf("invalid index params: %v", cit.CreateIndexRequest.ExtraParams)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -16,14 +16,12 @@ import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
)
|
||||
|
||||
@ -71,29 +69,6 @@ func getMin(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
func CheckIntByRange(params map[string]string, key string, min, max int) bool {
|
||||
valueStr, ok := params[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
func CheckStrByValues(params map[string]string, key string, container []string) bool {
|
||||
value, ok := params[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return funcutil.SliceContain(container, value)
|
||||
}
|
||||
|
||||
func GetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (string, error) {
|
||||
for _, kv := range kvs {
|
||||
if kv.Key == key {
|
||||
|
@ -50,87 +50,6 @@ func TestGetPulsarConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIntByRange(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
"s1": "s1",
|
||||
"s2": "s2",
|
||||
"s3": "s3",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
min int
|
||||
max int
|
||||
want bool
|
||||
}{
|
||||
{params, "1", 0, 4, true},
|
||||
{params, "2", 0, 4, true},
|
||||
{params, "3", 0, 4, true},
|
||||
{params, "1", 4, 5, false},
|
||||
{params, "2", 4, 5, false},
|
||||
{params, "3", 4, 5, false},
|
||||
{params, "4", 0, 4, false},
|
||||
{params, "5", 0, 4, false},
|
||||
{params, "6", 0, 4, false},
|
||||
{params, "s1", 0, 4, false},
|
||||
{params, "s2", 0, 4, false},
|
||||
{params, "s3", 0, 4, false},
|
||||
{params, "s4", 0, 4, false},
|
||||
{params, "s5", 0, 4, false},
|
||||
{params, "s6", 0, 4, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
||||
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckStrByValues(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
container []string
|
||||
want bool
|
||||
}{
|
||||
{params, "1", []string{"1", "2", "3"}, true},
|
||||
{params, "2", []string{"1", "2", "3"}, true},
|
||||
{params, "3", []string{"1", "2", "3"}, true},
|
||||
{params, "1", []string{"4", "5", "6"}, false},
|
||||
{params, "2", []string{"4", "5", "6"}, false},
|
||||
{params, "3", []string{"4", "5", "6"}, false},
|
||||
{params, "1", []string{}, false},
|
||||
{params, "2", []string{}, false},
|
||||
{params, "3", []string{}, false},
|
||||
{params, "4", []string{"1", "2", "3"}, false},
|
||||
{params, "5", []string{"1", "2", "3"}, false},
|
||||
{params, "6", []string{"1", "2", "3"}, false},
|
||||
{params, "4", []string{"4", "5", "6"}, false},
|
||||
{params, "5", []string{"4", "5", "6"}, false},
|
||||
{params, "6", []string{"4", "5", "6"}, false},
|
||||
{params, "4", []string{}, false},
|
||||
{params, "5", []string{}, false},
|
||||
{params, "6", []string{}, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
||||
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttrByKeyFromRepeatedKV(t *testing.T) {
|
||||
kvs := []*commonpb.KeyValuePair{
|
||||
{Key: "Key1", Value: "Value1"},
|
||||
|
@ -9,7 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -87,7 +87,7 @@ const (
|
||||
)
|
||||
|
||||
var METRICS = []string{L2, IP} // const
|
||||
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUBSTRUCTURE} // const
|
||||
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
|
||||
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
|
||||
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||
@ -100,9 +100,10 @@ type BaseConfAdapter struct {
|
||||
}
|
||||
|
||||
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
|
||||
return CheckStrByValues(params, Metric, METRICS)
|
||||
}
|
||||
@ -138,7 +139,9 @@ func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, NBITS, MinNBits, MaxNBits) {
|
||||
// nbits can be set to default: 8
|
||||
_, nbitsExist := params[NBITS]
|
||||
if nbitsExist && !CheckIntByRange(params, NBITS, MinNBits, MaxNBits) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -205,9 +208,10 @@ type BinIDMAPConfAdapter struct {
|
||||
}
|
||||
|
||||
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
|
||||
return CheckStrByValues(params, Metric, BinIDMapMetrics)
|
||||
}
|
||||
@ -220,9 +224,10 @@ type BinIVFConfAdapter struct {
|
||||
}
|
||||
|
||||
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return false
|
@ -9,7 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"errors"
|
@ -9,7 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
@ -9,7 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -68,11 +68,19 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithoutNbits, true},
|
||||
}
|
||||
|
||||
adapter := newIVFPQConfAdapter()
|
@ -9,7 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
package indexparamcheck
|
||||
|
||||
type IndexType = string
|
||||
|
30
internal/util/indexparamcheck/utils.go
Normal file
30
internal/util/indexparamcheck/utils.go
Normal file
@ -0,0 +1,30 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
)
|
||||
|
||||
func CheckIntByRange(params map[string]string, key string, min, max int) bool {
|
||||
valueStr, ok := params[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
func CheckStrByValues(params map[string]string, key string, container []string) bool {
|
||||
value, ok := params[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return funcutil.SliceContain(container, value)
|
||||
}
|
87
internal/util/indexparamcheck/utils_test.go
Normal file
87
internal/util/indexparamcheck/utils_test.go
Normal file
@ -0,0 +1,87 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckIntByRange(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
"s1": "s1",
|
||||
"s2": "s2",
|
||||
"s3": "s3",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
min int
|
||||
max int
|
||||
want bool
|
||||
}{
|
||||
{params, "1", 0, 4, true},
|
||||
{params, "2", 0, 4, true},
|
||||
{params, "3", 0, 4, true},
|
||||
{params, "1", 4, 5, false},
|
||||
{params, "2", 4, 5, false},
|
||||
{params, "3", 4, 5, false},
|
||||
{params, "4", 0, 4, false},
|
||||
{params, "5", 0, 4, false},
|
||||
{params, "6", 0, 4, false},
|
||||
{params, "s1", 0, 4, false},
|
||||
{params, "s2", 0, 4, false},
|
||||
{params, "s3", 0, 4, false},
|
||||
{params, "s4", 0, 4, false},
|
||||
{params, "s5", 0, 4, false},
|
||||
{params, "s6", 0, 4, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
||||
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckStrByValues(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"1": strconv.Itoa(1),
|
||||
"2": strconv.Itoa(2),
|
||||
"3": strconv.Itoa(3),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
key string
|
||||
container []string
|
||||
want bool
|
||||
}{
|
||||
{params, "1", []string{"1", "2", "3"}, true},
|
||||
{params, "2", []string{"1", "2", "3"}, true},
|
||||
{params, "3", []string{"1", "2", "3"}, true},
|
||||
{params, "1", []string{"4", "5", "6"}, false},
|
||||
{params, "2", []string{"4", "5", "6"}, false},
|
||||
{params, "3", []string{"4", "5", "6"}, false},
|
||||
{params, "1", []string{}, false},
|
||||
{params, "2", []string{}, false},
|
||||
{params, "3", []string{}, false},
|
||||
{params, "4", []string{"1", "2", "3"}, false},
|
||||
{params, "5", []string{"1", "2", "3"}, false},
|
||||
{params, "6", []string{"1", "2", "3"}, false},
|
||||
{params, "4", []string{"4", "5", "6"}, false},
|
||||
{params, "5", []string{"4", "5", "6"}, false},
|
||||
{params, "6", []string{"4", "5", "6"}, false},
|
||||
{params, "4", []string{}, false},
|
||||
{params, "5", []string{}, false},
|
||||
{params, "6", []string{}, false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
||||
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user