From ab2d8dd0a72552534f90d4e8d0183fb440e4ca71 Mon Sep 17 00:00:00 2001 From: Gao Date: Thu, 21 Sep 2023 14:51:25 +0800 Subject: [PATCH] Add ScaNN index param checker (#27268) Signed-off-by: chasingegg --- pkg/util/indexparamcheck/conf_adapter_mgr.go | 2 +- .../indexparamcheck/conf_adapter_mgr_test.go | 4 +- pkg/util/indexparamcheck/ivf_pq_checker.go | 2 +- .../indexparamcheck/raft_ivf_pq_checker.go | 2 +- pkg/util/indexparamcheck/scann_checker.go | 41 +++++ .../indexparamcheck/scann_checker_test.go | 170 ++++++++++++++++++ 6 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 pkg/util/indexparamcheck/scann_checker.go create mode 100644 pkg/util/indexparamcheck/scann_checker_test.go diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/pkg/util/indexparamcheck/conf_adapter_mgr.go index dd60ae638a..5099da0ca3 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr.go @@ -48,7 +48,7 @@ func (mgr *indexCheckerMgrImpl) registerIndexChecker() { mgr.checkers[IndexFaissIDMap] = newFlatChecker() mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() - mgr.checkers[IndexScaNN] = newIVFBaseChecker() + mgr.checkers[IndexScaNN] = newScaNNChecker() mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker() mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker() diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go index 370a98e2c2..6ab9469ee5 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go @@ -44,7 +44,7 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexScaNN) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) @@ -104,7 +104,7 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexScaNN) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) diff --git a/pkg/util/indexparamcheck/ivf_pq_checker.go b/pkg/util/indexparamcheck/ivf_pq_checker.go index 07c830f26e..51da64e0ff 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker.go +++ b/pkg/util/indexparamcheck/ivf_pq_checker.go @@ -53,7 +53,7 @@ func (c *ivfPQChecker) checkPQParams(params map[string]string) error { func (c *ivfPQChecker) checkCPUPQParams(dimension, m int) error { if (dimension % m) != 0 { - return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m) + return fmt.Errorf("dimension must be able to be divided by `m`, dimension: %d, m: %d", dimension, m) } return nil } diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go index 8ce89d72e4..65f6d1d1b7 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go @@ -53,7 +53,7 @@ func (c *raftIVFPQChecker) checkPQParams(params map[string]string) error { return nil } if dimension%m != 0 { - return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m) + return fmt.Errorf("dimension must be able to be divided by `m`, dimension: %d, m: %d", dimension, m) } return nil } diff --git a/pkg/util/indexparamcheck/scann_checker.go b/pkg/util/indexparamcheck/scann_checker.go new file mode 100644 index 0000000000..eecf2ded64 --- /dev/null +++ b/pkg/util/indexparamcheck/scann_checker.go @@ -0,0 +1,41 @@ +package indexparamcheck + +import ( + "fmt" + "strconv" +) + +// scaNNChecker checks if a SCANN index can be built. +type scaNNChecker struct { + ivfBaseChecker +} + +// CheckTrain checks if SCANN index can be built with the specific index parameters. +func (c *scaNNChecker) CheckTrain(params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(params); err != nil { + return err + } + + return c.checkScaNNParams(params) +} + +func (c *scaNNChecker) checkScaNNParams(params map[string]string) error { + dimStr, dimensionExist := params[DIM] + if !dimensionExist { + return fmt.Errorf("dimension not found") + } + + dimension, err := strconv.Atoi(dimStr) + if err != nil { // invalid dimension + return fmt.Errorf("invalid dimension: %s", dimStr) + } + + if (dimension % 2) != 0 { + return fmt.Errorf("dimension must be able to be divided by 2, dimension: %d", dimension) + } + return nil +} + +func newScaNNChecker() IndexChecker { + return &scaNNChecker{} +} diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/pkg/util/indexparamcheck/scann_checker_test.go new file mode 100644 index 0000000000..91a1f6c6a0 --- /dev/null +++ b/pkg/util/indexparamcheck/scann_checker_test.go @@ -0,0 +1,170 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/metric" + + "github.com/stretchr/testify/assert" +) + +func Test_scaNNChecker_CheckTrain(t *testing.T) { + validParams := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + paramsNotMultiplier := map[string]string{ + DIM: strconv.Itoa(127), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + validParamsWithoutDim := map[string]string{ + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + invalidParamsDim := copyParams(validParams) + invalidParamsDim[DIM] = "NAN" + + p1 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + p2 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.IP, + } + p3 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.COSINE, + } + + p4 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.HAMMING, + } + p5 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.JACCARD, + } + p6 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUBSTRUCTURE, + } + p7 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUPERSTRUCTURE, + } + + cases := []struct { + params map[string]string + errIsNil bool + }{ + {validParams, true}, + {paramsNotMultiplier, false}, + {invalidIVFParamsMin(), false}, + {invalidIVFParamsMax(), false}, + {validParamsWithoutDim, false}, + {invalidParamsDim, false}, + {p1, true}, + {p2, true}, + {p3, true}, + {p4, false}, + {p5, false}, + {p6, false}, + {p7, false}, + } + + c := newScaNNChecker() + for _, test := range cases { + err := c.CheckTrain(test.params) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} + +func Test_scaNNChecker_CheckValidDataType(t *testing.T) { + + cases := []struct { + dType schemapb.DataType + errIsNil bool + }{ + { + dType: schemapb.DataType_Bool, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int8, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int16, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int32, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int64, + errIsNil: false, + }, + { + dType: schemapb.DataType_Float, + errIsNil: false, + }, + { + dType: schemapb.DataType_Double, + errIsNil: false, + }, + { + dType: schemapb.DataType_String, + errIsNil: false, + }, + { + dType: schemapb.DataType_VarChar, + errIsNil: false, + }, + { + dType: schemapb.DataType_Array, + errIsNil: false, + }, + { + dType: schemapb.DataType_JSON, + errIsNil: false, + }, + { + dType: schemapb.DataType_FloatVector, + errIsNil: true, + }, + { + dType: schemapb.DataType_BinaryVector, + errIsNil: false, + }, + } + + c := newScaNNChecker() + for _, test := range cases { + err := c.CheckValidDataType(test.dType) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +}