Refactor check logic of index parameters (#23856)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2023-05-06 10:40:39 +08:00 committed by GitHub
parent 899702f13c
commit 7be7e6f360
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 2325 additions and 872 deletions

View File

@ -1645,7 +1645,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
zap.String("field", request.FieldName),
zap.Any("extra_params", request.ExtraParams))
log.Debug(rpcReceived(method))
log.Info(rpcReceived(method))
if err := node.sched.ddQueue.Enqueue(cit); err != nil {
log.Warn(
@ -1661,7 +1661,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
}, nil
}
log.Debug(
log.Info(
rpcEnqueued(method),
zap.Uint64("BeginTs", cit.BeginTs()),
zap.Uint64("EndTs", cit.EndTs()))
@ -1682,7 +1682,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
}, nil
}
log.Debug(
log.Info(
rpcDone(method),
zap.Uint64("BeginTs", cit.BeginTs()),
zap.Uint64("EndTs", cit.EndTs()))

View File

@ -242,9 +242,9 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
}
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType)
if err != nil {
log.Warn("Failed to get conf adapter", zap.String(common.IndexTypeKey, indexType))
log.Warn("Failed to get index checker", zap.String(common.IndexTypeKey, indexType))
return fmt.Errorf("invalid index type: %s", indexType)
}
@ -252,16 +252,14 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
return err
}
ok := adapter.CheckValidDataType(field.GetDataType())
if !ok {
log.Warn("Field data type don't support the index build type", zap.String("fieldDataType", field.GetDataType().String()), zap.String("indexType", indexType))
return fmt.Errorf("field data type %s don't support the index build type %s", field.GetDataType().String(), indexType)
if err := checker.CheckValidDataType(field.GetDataType()); err != nil {
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
return err
}
ok = adapter.CheckTrain(indexParams)
if !ok {
log.Warn("Create index with invalid params", zap.Any("index_params", indexParams))
return fmt.Errorf("invalid index params: %v", indexParams)
if err := checker.CheckTrain(indexParams); err != nil {
log.Info("create index with invalid parameters", zap.Error(err))
return err
}
return nil

View File

@ -0,0 +1,25 @@
package indexparamcheck
import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
type baseChecker struct {
}
func (c *baseChecker) CheckTrain(params map[string]string) error {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return errOutOfRange(DIM, DefaultMinDim, DefaultMaxDim)
}
return nil
}
// CheckValidDataType check whether the field data type is supported for the index type
func (c *baseChecker) CheckValidDataType(dType schemapb.DataType) error {
return nil
}
func newBaseChecker() IndexChecker {
return &baseChecker{}
}

View File

@ -0,0 +1,107 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_baseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
paramsWithoutDim := map[string]string{
Metric: L2,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{paramsWithoutDim, false},
}
c := newBaseChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_baseChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: true,
},
{
dType: schemapb.DataType_Int8,
errIsNil: true,
},
{
dType: schemapb.DataType_Int16,
errIsNil: true,
},
{
dType: schemapb.DataType_Int32,
errIsNil: true,
},
{
dType: schemapb.DataType_Int64,
errIsNil: true,
},
{
dType: schemapb.DataType_Float,
errIsNil: true,
},
{
dType: schemapb.DataType_Double,
errIsNil: true,
},
{
dType: schemapb.DataType_String,
errIsNil: true,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: true,
},
{
dType: schemapb.DataType_Array,
errIsNil: true,
},
{
dType: schemapb.DataType_JSON,
errIsNil: true,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: true,
},
}
c := newBaseChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,14 @@
package indexparamcheck
type binFlatChecker struct {
binaryVectorBaseChecker
}
// CheckTrain checks if a binary flat index can be built with the specific parameters.
func (c *binFlatChecker) CheckTrain(params map[string]string) error {
return c.binaryVectorBaseChecker.CheckTrain(params)
}
func newBinFlatChecker() IndexChecker {
return &binFlatChecker{}
}

View File

@ -0,0 +1,151 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_binFlatChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
Metric: JACCARD,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{paramsWithoutDim, false},
{p1, false},
{p2, false},
{p3, false},
{p4, true},
{p5, true},
{p6, true},
{p7, true},
{p8, true},
}
c := newBinFlatChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_binFlatChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: false,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: true,
},
}
c := newBinFlatChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,29 @@
package indexparamcheck
import (
"fmt"
)
type binIVFFlatChecker struct {
binaryVectorBaseChecker
}
func (c *binIVFFlatChecker) CheckTrain(params map[string]string) error {
if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", BinIvfMetrics)
}
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return errOutOfRange(NLIST, MinNList, MaxNList)
}
return nil
}
func newBinIVFFlatChecker() IndexChecker {
return &binIVFFlatChecker{}
}

View File

@ -0,0 +1,207 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
invalidParams := copyParams(validParams)
invalidParams[Metric] = L2
paramsWithLargeNlist := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MaxNList + 1),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
paramsWithSmallNlist := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MinNList - 1),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{paramsWithoutDim, false},
{paramsWithLargeNlist, false},
{paramsWithSmallNlist, false},
{invalidParams, false},
{p1, false},
{p2, false},
{p3, false},
{p4, true},
{p5, true},
{p6, true},
{p7, false},
{p8, false},
}
c := newBinIVFFlatChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: false,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: true,
},
}
c := newBinIVFFlatChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,34 @@
package indexparamcheck
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
type binaryVectorBaseChecker struct {
baseChecker
}
func (c *binaryVectorBaseChecker) CheckTrain(params map[string]string) error {
if err := c.baseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, BinIDMapMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", BinIDMapMetrics)
}
return nil
}
func (c *binaryVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
if dType != schemapb.DataType_BinaryVector {
return fmt.Errorf("binary vector is only supported")
}
return nil
}
func newBinaryVectorBaseChecker() IndexChecker {
return &binaryVectorBaseChecker{}
}

View File

@ -0,0 +1,79 @@
package indexparamcheck
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: false,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: true,
},
}
c := newBinaryVectorBaseChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -1,381 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexparamcheck
import (
"strconv"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
const (
// L2 represents Euclidean distance
L2 = "L2"
// IP represents inner product distance
IP = "IP"
// COSINE represents cosine distance
COSINE = "COSINE"
// HAMMING represents hamming distance
HAMMING = "HAMMING"
// JACCARD represents jaccard distance
JACCARD = "JACCARD"
// TANIMOTO represents tanimoto distance
TANIMOTO = "TANIMOTO"
// SUBSTRUCTURE represents substructure distance
SUBSTRUCTURE = "SUBSTRUCTURE"
// SUPERSTRUCTURE represents superstructure distance
SUPERSTRUCTURE = "SUPERSTRUCTURE"
MinNBits = 1
MaxNBits = 16
DefaultNBits = 8
// MinNList is the lower limit of nlist that used in Index IVFxxx
MinNList = 1
// MaxNList is the upper limit of nlist that used in Index IVFxxx
MaxNList = 65536
// DefaultMinDim is the smallest dimension supported in Milvus
DefaultMinDim = 1
// DefaultMaxDim is the largest dimension supported in Milvus
DefaultMaxDim = 32768
// If Dim = 32 and raw vector data = 2G, query node need 24G disk space When loading the vectors' disk index
// If Dim = 2, and raw vector data = 2G, query node need 240G disk space When loading the vectors' disk index
// So DiskAnnMinDim should be greater than or equal to 32 to avoid running out of disk space
DiskAnnMinDim = 32
HNSWMinEfConstruction = 8
HNSWMaxEfConstruction = 512
HNSWMinM = 4
HNSWMaxM = 64
// DIM is a constant used to represent dimension
DIM = "dim"
// Metric is a constant used to metric type
Metric = "metric_type"
// NLIST is a constant used to nlist in Index IVFxxx
NLIST = "nlist"
NBITS = "nbits"
IVFM = "m"
EFConstruction = "efConstruction"
HNSWM = "M"
)
// METRICS is a set of all metrics types supported for float vector.
var METRICS = []string{L2, IP, COSINE} // const
// BinIDMapMetrics is a set of all metric types supported for binary vector.
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
type ConfAdapter interface {
// CheckTrain returns true if the index can be built with the specific index parameters.
CheckTrain(map[string]string) bool
CheckValidDataType(dType schemapb.DataType) bool
}
// BaseConfAdapter checks if a `FLAT` index can be built.
type BaseConfAdapter struct {
}
// CheckTrain check whether the params contains supported metrics types
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, METRICS)
}
// CheckValidDataType check whether the field data type is supported for the index type
func (adapter *BaseConfAdapter) CheckValidDataType(dType schemapb.DataType) bool {
return true
}
func newBaseConfAdapter() *BaseConfAdapter {
return &BaseConfAdapter{}
}
// IVFConfAdapter checks if a IVF index can be built.
type IVFConfAdapter struct {
BaseConfAdapter
}
// CheckTrain returns true if the index can be built with the specific index parameters.
func (adapter *IVFConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return false
}
// skip check number of rows
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newIVFConfAdapter() *IVFConfAdapter {
return &IVFConfAdapter{}
}
// IVFPQConfAdapter checks if a IVF_PQ index can be built.
type IVFPQConfAdapter struct {
IVFConfAdapter
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.IVFConfAdapter.CheckTrain(params) {
return false
}
return adapter.checkPQParams(params)
}
func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool {
dimStr, dimensionExist := params[DIM]
if !dimensionExist {
return false
}
dimension, err := strconv.Atoi(dimStr)
if err != nil { // invalid dimension
return false
}
// nbits can be set to default: 8
nbitsStr, nbitsExist := params[NBITS]
if nbitsExist {
_, err := strconv.Atoi(nbitsStr)
if err != nil { // invalid nbits
return false
}
}
mStr, ok := params[IVFM]
if !ok {
return false
}
m, err := strconv.Atoi(mStr)
if err != nil || m == 0 { // invalid m
return false
}
return adapter.checkCPUPQParams(dimension, m)
}
func (adapter *IVFPQConfAdapter) checkGPUPQParams(dimension, m, nbits int) bool {
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
* no precomputed codes. Precomputed codes supports any number of dimensions, but will involve memory overheads.
*/
subDim := dimension / m
return funcutil.SliceContain(supportSubQuantizer, m) && funcutil.SliceContain(supportDimPerSubQuantizer, subDim) && nbits == 8
}
func (adapter *IVFPQConfAdapter) checkCPUPQParams(dimension, m int) bool {
return (dimension % m) == 0
}
func newIVFPQConfAdapter() *IVFPQConfAdapter {
return &IVFPQConfAdapter{}
}
// RaftIVFPQConfAdapter checks if a RAFT_IVF_PQ index can be built.
type RaftIVFPQConfAdapter struct {
IVFConfAdapter
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (adapter *RaftIVFPQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.IVFConfAdapter.CheckTrain(params) {
return false
}
return adapter.checkPQParams(params)
}
func (adapter *RaftIVFPQConfAdapter) checkPQParams(params map[string]string) bool {
dimStr, dimensionExist := params[DIM]
if !dimensionExist {
return false
}
dimension, err := strconv.Atoi(dimStr)
if err != nil { // invalid dimension
return false
}
// nbits can be set to default: 8
nbitsStr, nbitsExist := params[NBITS]
if nbitsExist {
_, err := strconv.Atoi(nbitsStr)
if err != nil { // invalid nbits
return false
}
}
mStr, ok := params[IVFM]
if !ok {
return false
}
m, err := strconv.Atoi(mStr)
if err != nil { // invalid m
return false
}
// here is the only difference with IVF_PQ
if m == 0 {
return true
}
return dimension%m == 0
}
func newRaftIVFPQConfAdapter() *RaftIVFPQConfAdapter {
return &RaftIVFPQConfAdapter{}
}
// IVFSQConfAdapter checks if a IVF_SQ index can be built.
type IVFSQConfAdapter struct {
IVFConfAdapter
}
func (adapter *IVFSQConfAdapter) checkNBits(params map[string]string) bool {
// cgo will set this key to DefaultNBits (8), which is the only value Milvus supports.
_, exist := params[NBITS]
if exist {
// 8 is the only supported nbits.
return CheckIntByRange(params, NBITS, DefaultNBits, DefaultNBits)
}
return true
}
// CheckTrain returns true if the index can be built with the specific index parameters.
func (adapter *IVFSQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.checkNBits(params) {
return false
}
return adapter.IVFConfAdapter.CheckTrain(params)
}
func newIVFSQConfAdapter() *IVFSQConfAdapter {
return &IVFSQConfAdapter{}
}
type BinIDMAPConfAdapter struct {
BaseConfAdapter
}
// CheckTrain checks if a binary flat index can be built with the specific parameters.
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, BinIDMapMetrics)
}
func newBinIDMAPConfAdapter() *BinIDMAPConfAdapter {
return &BinIDMAPConfAdapter{}
}
// BinIVFConfAdapter checks if a bin IFV index can be built.
type BinIVFConfAdapter struct {
BaseConfAdapter
}
// CheckTrain checks if a binary ivf index can be built with specific parameters.
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return false
}
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
return false
}
// skip checking the number of rows
return true
}
func newBinIVFConfAdapter() *BinIVFConfAdapter {
return &BinIVFConfAdapter{}
}
// HNSWConfAdapter checks if a hnsw index can be built.
type HNSWConfAdapter struct {
BaseConfAdapter
}
// CheckTrain checks if a hnsw index can be built with specific parameters.
func (adapter *HNSWConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return false
}
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newHNSWConfAdapter() *HNSWConfAdapter {
return &HNSWConfAdapter{}
}
// DISKANNConfAdapter checks if an diskann index can be built.
type DISKANNConfAdapter struct {
BaseConfAdapter
}
func (adapter *DISKANNConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DiskAnnMinDim, DefaultMaxDim) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
// CheckValidDataType check whether the field data type is supported for the index type
func (adapter *DISKANNConfAdapter) CheckValidDataType(dType schemapb.DataType) bool {
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
}
return funcutil.SliceContain(vecDataTypes, dType)
}
func newDISKANNConfAdapter() *DISKANNConfAdapter {
return &DISKANNConfAdapter{}
}

View File

@ -22,56 +22,53 @@ import (
"github.com/cockroachdb/errors"
)
// ConfAdapterMgr manages the conf adapter.
type ConfAdapterMgr interface {
// GetAdapter gets the conf adapter by the index type.
GetAdapter(indexType string) (ConfAdapter, error)
type IndexCheckerMgr interface {
GetChecker(indexType string) (IndexChecker, error)
}
// ConfAdapterMgrImpl implements ConfAdapter.
type ConfAdapterMgrImpl struct {
adapters map[IndexType]ConfAdapter
// indexCheckerMgrImpl implements IndexChecker.
type indexCheckerMgrImpl struct {
checkers map[IndexType]IndexChecker
once sync.Once
}
// GetAdapter gets the conf adapter by the index type.
func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) {
mgr.once.Do(mgr.registerConfAdapter)
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
mgr.once.Do(mgr.registerIndexChecker)
adapter, ok := mgr.adapters[indexType]
adapter, ok := mgr.checkers[indexType]
if ok {
return adapter, nil
}
return nil, errors.New("Can not find conf adapter: " + indexType)
}
func (mgr *ConfAdapterMgrImpl) registerConfAdapter() {
mgr.adapters[IndexRaftIvfFlat] = newIVFConfAdapter()
mgr.adapters[IndexRaftIvfPQ] = newRaftIVFPQConfAdapter()
mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter()
mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter()
mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter()
mgr.adapters[IndexFaissIvfSQ8] = newIVFSQConfAdapter()
mgr.adapters[IndexFaissBinIDMap] = newBinIDMAPConfAdapter()
mgr.adapters[IndexFaissBinIvfFlat] = newBinIVFConfAdapter()
mgr.adapters[IndexHNSW] = newHNSWConfAdapter()
mgr.adapters[IndexDISKANN] = newDISKANNConfAdapter()
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
mgr.checkers[IndexFaissIDMap] = newBaseChecker()
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker()
mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker()
mgr.checkers[IndexHNSW] = newHnswChecker()
mgr.checkers[IndexDISKANN] = newDiskannChecker()
}
func newConfAdapterMgrImpl() *ConfAdapterMgrImpl {
return &ConfAdapterMgrImpl{
adapters: make(map[IndexType]ConfAdapter),
func newIndexCheckerMgr() *indexCheckerMgrImpl {
return &indexCheckerMgrImpl{
checkers: make(map[IndexType]IndexChecker),
}
}
var confAdapterMgr ConfAdapterMgr
var indexCheckerMgr IndexCheckerMgr
var getConfAdapterMgrOnce sync.Once
var getIndexCheckerMgrOnce sync.Once
// GetConfAdapterMgrInstance gets the instance of ConfAdapterMgr.
func GetConfAdapterMgrInstance() ConfAdapterMgr {
getConfAdapterMgrOnce.Do(func() {
confAdapterMgr = newConfAdapterMgrImpl()
// GetIndexCheckerMgrInstance gets the instance of IndexCheckerMgr.
func GetIndexCheckerMgrInstance() IndexCheckerMgr {
getIndexCheckerMgrOnce.Do(func() {
indexCheckerMgr = newIndexCheckerMgr()
})
return confAdapterMgr
return indexCheckerMgr
}

View File

@ -19,122 +19,122 @@ import (
)
func Test_GetConfAdapterMgrInstance(t *testing.T) {
adapterMgr := GetConfAdapterMgrInstance()
adapterMgr := GetIndexCheckerMgrInstance()
var adapter ConfAdapter
var adapter IndexChecker
var err error
var ok bool
adapter, err = adapterMgr.GetAdapter("invalid")
adapter, err = adapterMgr.GetChecker("invalid")
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, adapter)
adapter, err = adapterMgr.GetAdapter(IndexFaissIDMap)
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BaseConfAdapter)
_, ok = adapter.(*baseChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfFlat)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFConfAdapter)
_, ok = adapter.(*ivfBaseChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfPQ)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFPQConfAdapter)
_, ok = adapter.(*ivfPQChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfSQ8)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFSQConfAdapter)
_, ok = adapter.(*ivfSQChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIDMap)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BinIDMAPConfAdapter)
_, ok = adapter.(*binFlatChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIvfFlat)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BinIVFConfAdapter)
_, ok = adapter.(*binIVFFlatChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexHNSW)
adapter, err = adapterMgr.GetChecker(IndexHNSW)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*HNSWConfAdapter)
_, ok = adapter.(*hnswChecker)
assert.Equal(t, true, ok)
}
func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
adapterMgr := newConfAdapterMgrImpl()
adapterMgr := newIndexCheckerMgr()
var adapter ConfAdapter
var adapter IndexChecker
var err error
var ok bool
adapter, err = adapterMgr.GetAdapter("invalid")
adapter, err = adapterMgr.GetChecker("invalid")
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, adapter)
adapter, err = adapterMgr.GetAdapter(IndexFaissIDMap)
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BaseConfAdapter)
_, ok = adapter.(*baseChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfFlat)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFConfAdapter)
_, ok = adapter.(*ivfBaseChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfPQ)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFPQConfAdapter)
_, ok = adapter.(*ivfPQChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfSQ8)
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*IVFSQConfAdapter)
_, ok = adapter.(*ivfSQChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIDMap)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BinIDMAPConfAdapter)
_, ok = adapter.(*binFlatChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIvfFlat)
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*BinIVFConfAdapter)
_, ok = adapter.(*binIVFFlatChecker)
assert.Equal(t, true, ok)
adapter, err = adapterMgr.GetAdapter(IndexHNSW)
adapter, err = adapterMgr.GetChecker(IndexHNSW)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, adapter)
_, ok = adapter.(*HNSWConfAdapter)
_, ok = adapter.(*hnswChecker)
assert.Equal(t, true, ok)
}
func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) {
num := 4
mgr := newConfAdapterMgrImpl()
mgr := newIndexCheckerMgr()
var wg sync.WaitGroup
for i := 0; i < num; i++ {
wg.Add(1)
go func() {
defer wg.Done()
adapter, err := mgr.GetAdapter(IndexHNSW)
adapter, err := mgr.GetChecker(IndexHNSW)
assert.NoError(t, err)
assert.NotNil(t, adapter)
}()

View File

@ -1,410 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package indexparamcheck
import (
"strconv"
"testing"
)
// TODO: add more test cases which `ConfAdapter.CheckTrain` return false,
// for example, maybe we can copy test cases from regression test later.
func invalidIVFParamsMin() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MinNList - 1),
Metric: L2,
}
return invalidIVFParams
}
func invalidIVFParamsMax() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MaxNList + 1),
Metric: L2,
}
return invalidIVFParams
}
func copyParams(original map[string]string) map[string]string {
result := make(map[string]string)
for key, value := range original {
result[key] = value
}
return result
}
func TestBaseConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
paramsWithoutDim := map[string]string{
Metric: L2,
}
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
}
adapter := newBaseConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("BaseConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
// IVFConfAdapter checks if an ivf index can be built.
func TestIVFConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: L2,
}
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
}
adapter := newIVFConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("IVFConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsDim := copyParams(validParams)
invalidParamsDim[DIM] = "NAN"
invalidParamsNbits := copyParams(validParams)
invalidParamsNbits[NBITS] = "NAN"
invalidParamsWithoutIVF := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsIVF := copyParams(validParams)
invalidParamsIVF[IVFM] = "NAN"
invalidParamsM := copyParams(validParams)
invalidParamsM[DIM] = strconv.Itoa(65536)
invalidParamsMzero := copyParams(validParams)
invalidParamsMzero[IVFM] = "0"
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
{invalidParamsIVF, false},
{invalidParamsM, false},
{invalidParamsMzero, false},
}
adapter := newIVFPQConfAdapter()
for i, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Log("i:", i, "params", test.params)
t.Errorf("IVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
func TestRaftIVFPQConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsDim := copyParams(validParams)
invalidParamsDim[DIM] = "NAN"
invalidParamsNbits := copyParams(validParams)
invalidParamsNbits[NBITS] = "NAN"
invalidParamsWithoutIVF := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsIVF := copyParams(validParams)
invalidParamsIVF[IVFM] = "NAN"
invalidParamsM := copyParams(validParams)
invalidParamsM[DIM] = strconv.Itoa(65536)
validParamsMzero := copyParams(validParams)
validParamsMzero[IVFM] = "0"
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
{invalidParamsIVF, false},
{invalidParamsM, false},
{validParamsMzero, true},
}
adapter := newRaftIVFPQConfAdapter()
for i, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Log("i:", i, "params", test.params)
t.Errorf("RaftIVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
func TestIVFSQConfAdapter_CheckTrain(t *testing.T) {
getValidParams := func(withNBits bool) map[string]string {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: L2,
}
if withNBits {
validParams[NBITS] = strconv.Itoa(DefaultNBits)
}
return validParams
}
validParams := getValidParams(false)
validParamsWithNBits := getValidParams(true)
paramsWithInvalidNBits := getValidParams(false)
paramsWithInvalidNBits[NBITS] = strconv.Itoa(DefaultNBits + 1)
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{validParamsWithNBits, true},
{paramsWithInvalidNBits, false},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
}
adapter := newIVFSQConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("IVFSQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
// BinIDMAPConfAdapter checks if a bin id map index can be built.
func TestBinIDMAPConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
Metric: JACCARD,
}
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
}
adapter := newBinIDMAPConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("BinIDMAPConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
// BinIVFConfAdapter checks if a bin ivf index can be built.
func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
paramsWithoutDim := map[string]string{
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
invalidParams := copyParams(validParams)
invalidParams[Metric] = L2
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{paramsWithoutDim, false},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{invalidParams, false},
}
adapter := newBinIVFConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("BinIVFConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
// HNSWConfAdapter checks if a hnsw index can be built.
func TestHNSWConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: L2,
}
invalidEfParamsMin := copyParams(validParams)
invalidEfParamsMin[EFConstruction] = strconv.Itoa(HNSWMinEfConstruction - 1)
invalidEfParamsMax := copyParams(validParams)
invalidEfParamsMax[EFConstruction] = strconv.Itoa(HNSWMaxEfConstruction + 1)
invalidMParamsMin := copyParams(validParams)
invalidMParamsMin[HNSWM] = strconv.Itoa(HNSWMinM - 1)
invalidMParamsMax := copyParams(validParams)
invalidMParamsMax[HNSWM] = strconv.Itoa(HNSWMaxM + 1)
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{invalidEfParamsMin, false},
{invalidEfParamsMax, false},
{invalidMParamsMin, false},
{invalidMParamsMax, false},
}
adapter := newHNSWConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("HNSWConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
// DISKANNConfAdapter checks if an diskann index can be built
func TestDiskAnnConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
validParamsBigDim := copyParams(validParams)
validParamsBigDim[DIM] = strconv.Itoa(2048)
invalidParamsWithoutDim := map[string]string{
Metric: L2,
}
invalidParamsSmallDim := copyParams(validParams)
invalidParamsSmallDim[DIM] = strconv.Itoa(15)
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{validParamsBigDim, true},
{invalidParamsWithoutDim, false},
{invalidParamsSmallDim, false},
}
adapter := newDISKANNConfAdapter()
for _, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Errorf("DiskAnnConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}

View File

@ -0,0 +1,72 @@
package indexparamcheck
const (
// L2 represents Euclidean distance
L2 = "L2"
// IP represents inner product distance
IP = "IP"
// COSINE represents cosine distance
COSINE = "COSINE"
// HAMMING represents hamming distance
HAMMING = "HAMMING"
// JACCARD represents jaccard distance
JACCARD = "JACCARD"
// TANIMOTO represents tanimoto distance
TANIMOTO = "TANIMOTO"
// SUBSTRUCTURE represents substructure distance
SUBSTRUCTURE = "SUBSTRUCTURE"
// SUPERSTRUCTURE represents superstructure distance
SUPERSTRUCTURE = "SUPERSTRUCTURE"
MinNBits = 1
MaxNBits = 16
DefaultNBits = 8
// MinNList is the lower limit of nlist that used in Index IVFxxx
MinNList = 1
// MaxNList is the upper limit of nlist that used in Index IVFxxx
MaxNList = 65536
// DefaultMinDim is the smallest dimension supported in Milvus
DefaultMinDim = 1
// DefaultMaxDim is the largest dimension supported in Milvus
DefaultMaxDim = 32768
// If Dim = 32 and raw vector data = 2G, query node need 24G disk space When loading the vectors' disk index
// If Dim = 2, and raw vector data = 2G, query node need 240G disk space When loading the vectors' disk index
// So DiskAnnMinDim should be greater than or equal to 32 to avoid running out of disk space
DiskAnnMinDim = 32
HNSWMinEfConstruction = 8
HNSWMaxEfConstruction = 512
HNSWMinM = 4
HNSWMaxM = 64
// DIM is a constant used to represent dimension
DIM = "dim"
// Metric is a constant used to metric type
Metric = "metric_type"
// NLIST is a constant used to nlist in Index IVFxxx
NLIST = "nlist"
NBITS = "nbits"
IVFM = "m"
EFConstruction = "efConstruction"
HNSWM = "M"
)
// METRICS is a set of all metrics types supported for float vector.
var METRICS = []string{L2, IP, COSINE} // const
// BinIDMapMetrics is a set of all metric types supported for binary vector.
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const

View File

@ -0,0 +1,17 @@
package indexparamcheck
// diskannChecker checks if an diskann index can be built.
type diskannChecker struct {
floatVectorBaseChecker
}
func (c *diskannChecker) CheckTrain(params map[string]string) error {
if !CheckIntByRange(params, DIM, DiskAnnMinDim, DefaultMaxDim) {
return errOutOfRange(DIM, DiskAnnMinDim, DefaultMaxDim)
}
return c.floatVectorBaseChecker.CheckTrain(params)
}
func newDiskannChecker() IndexChecker {
return &diskannChecker{}
}

View File

@ -0,0 +1,159 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_diskannChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
validParamsBigDim := copyParams(validParams)
validParamsBigDim[DIM] = strconv.Itoa(2048)
invalidParamsWithoutDim := map[string]string{
Metric: L2,
}
invalidParamsSmallDim := copyParams(validParams)
invalidParamsSmallDim[DIM] = strconv.Itoa(15)
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{validParamsBigDim, true},
{invalidParamsWithoutDim, false},
{invalidParamsSmallDim, false},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newDiskannChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_diskannChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newDiskannChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,34 @@
package indexparamcheck
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
type floatVectorBaseChecker struct {
baseChecker
}
func (c *floatVectorBaseChecker) CheckTrain(params map[string]string) error {
if err := c.baseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, METRICS) {
return fmt.Errorf("metric type not found or not supported, supported: %v", METRICS)
}
return nil
}
func (c *floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
if dType != schemapb.DataType_FloatVector {
return fmt.Errorf("float vector is only supported")
}
return nil
}
func newFloatVectorBaseChecker() IndexChecker {
return &floatVectorBaseChecker{}
}

View File

@ -0,0 +1,79 @@
package indexparamcheck
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newFloatVectorBaseChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,21 @@
package indexparamcheck
type hnswChecker struct {
floatVectorBaseChecker
}
func (c *hnswChecker) CheckTrain(params map[string]string) error {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
}
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
return errOutOfRange(HNSWM, HNSWMinM, HNSWMaxM)
}
return c.floatVectorBaseChecker.CheckTrain(params)
}
func newHnswChecker() IndexChecker {
return &hnswChecker{}
}

View File

@ -0,0 +1,182 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_hnswChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: L2,
}
invalidEfParamsMin := copyParams(validParams)
invalidEfParamsMin[EFConstruction] = strconv.Itoa(HNSWMinEfConstruction - 1)
invalidEfParamsMax := copyParams(validParams)
invalidEfParamsMax[EFConstruction] = strconv.Itoa(HNSWMaxEfConstruction + 1)
invalidMParamsMin := copyParams(validParams)
invalidMParamsMin[HNSWM] = strconv.Itoa(HNSWMinM - 1)
invalidMParamsMax := copyParams(validParams)
invalidMParamsMax[HNSWM] = strconv.Itoa(HNSWMaxM + 1)
p1 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{invalidEfParamsMin, false},
{invalidEfParamsMax, false},
{invalidMParamsMin, false},
{invalidMParamsMax, false},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newHnswChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_hnswChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newHnswChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,26 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexparamcheck
import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
type IndexChecker interface {
CheckTrain(map[string]string) error
CheckValidDataType(dType schemapb.DataType) error
}

View File

@ -0,0 +1,45 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package indexparamcheck
import (
"strconv"
)
// TODO: add more test cases which `IndexChecker.CheckTrain` return false,
// for example, maybe we can copy test cases from regression test later.
func invalidIVFParamsMin() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MinNList - 1),
Metric: L2,
}
return invalidIVFParams
}
func invalidIVFParamsMax() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MaxNList + 1),
Metric: L2,
}
return invalidIVFParams
}
func copyParams(original map[string]string) map[string]string {
result := make(map[string]string)
for key, value := range original {
result[key] = value
}
return result
}

View File

@ -0,0 +1,19 @@
package indexparamcheck
type ivfBaseChecker struct {
floatVectorBaseChecker
}
func (c *ivfBaseChecker) CheckTrain(params map[string]string) error {
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return errOutOfRange(NLIST, MinNList, MaxNList)
}
// skip check number of rows
return c.floatVectorBaseChecker.CheckTrain(params)
}
func newIVFBaseChecker() IndexChecker {
return &ivfBaseChecker{}
}

View File

@ -0,0 +1,157 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: L2,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newIVFBaseChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newIVFBaseChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,63 @@
package indexparamcheck
import (
"fmt"
"strconv"
)
// ivfPQChecker checks if a IVF_PQ index can be built.
type ivfPQChecker struct {
ivfBaseChecker
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (c *ivfPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}
return c.checkPQParams(params)
}
func (c *ivfPQChecker) checkPQParams(params map[string]string) error {
dimStr, dimensionExist := params[DIM]
if !dimensionExist {
return fmt.Errorf("dimension not found")
}
dimension, err := strconv.Atoi(dimStr)
if err != nil { // invalid dimension
return fmt.Errorf("invalid dimension: %s", dimStr)
}
// nbits can be set to default: 8
nbitsStr, nbitsExist := params[NBITS]
if nbitsExist {
_, err := strconv.Atoi(nbitsStr)
if err != nil { // invalid nbits
return fmt.Errorf("invalid nbits: %s", nbitsStr)
}
}
mStr, ok := params[IVFM]
if !ok {
return fmt.Errorf("parameter `m` not found")
}
m, err := strconv.Atoi(mStr)
if err != nil || m == 0 { // invalid m
return fmt.Errorf("invalid `m`: %s", mStr)
}
return c.checkCPUPQParams(dimension, m)
}
func (c *ivfPQChecker) checkCPUPQParams(dimension, m int) error {
if (dimension % m) != 0 {
return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m)
}
return nil
}
func newIVFPQChecker() IndexChecker {
return &ivfPQChecker{}
}

View File

@ -0,0 +1,229 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_ivfPQChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
paramsNotMultiplier := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(5),
NBITS: strconv.Itoa(8),
Metric: L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsDim := copyParams(validParams)
invalidParamsDim[DIM] = "NAN"
invalidParamsNbits := copyParams(validParams)
invalidParamsNbits[NBITS] = "NAN"
invalidParamsWithoutIVF := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsIVF := copyParams(validParams)
invalidParamsIVF[IVFM] = "NAN"
invalidParamsM := copyParams(validParams)
invalidParamsM[DIM] = strconv.Itoa(65536)
invalidParamsMzero := copyParams(validParams)
invalidParamsMzero[IVFM] = "0"
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{paramsNotMultiplier, false},
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
{invalidParamsIVF, false},
{invalidParamsM, false},
{invalidParamsMzero, false},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newIVFPQChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_ivfPQChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newIVFPQChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,34 @@
package indexparamcheck
import (
"fmt"
)
// ivfSQChecker checks if a IVF_SQ index can be built.
type ivfSQChecker struct {
ivfBaseChecker
}
func (c *ivfSQChecker) checkNBits(params map[string]string) error {
// cgo will set this key to DefaultNBits (8), which is the only value Milvus supports.
_, exist := params[NBITS]
if exist {
// 8 is the only supported nbits.
if !CheckIntByRange(params, NBITS, DefaultNBits, DefaultNBits) {
return fmt.Errorf("nbits can be only set to 8 for IVF_SQ")
}
}
return nil
}
// CheckTrain returns true if the index can be built with the specific index parameters.
func (c *ivfSQChecker) CheckTrain(params map[string]string) error {
if err := c.checkNBits(params); err != nil {
return err
}
return c.ivfBaseChecker.CheckTrain(params)
}
func newIVFSQChecker() IndexChecker {
return &ivfSQChecker{}
}

View File

@ -0,0 +1,178 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_ivfSQChecker_CheckTrain(t *testing.T) {
getValidParams := func(withNBits bool) map[string]string {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: L2,
}
if withNBits {
validParams[NBITS] = strconv.Itoa(DefaultNBits)
}
return validParams
}
validParams := getValidParams(false)
validParamsWithNBits := getValidParams(true)
paramsWithInvalidNBits := getValidParams(false)
paramsWithInvalidNBits[NBITS] = strconv.Itoa(DefaultNBits + 1)
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{validParamsWithNBits, true},
{paramsWithInvalidNBits, false},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newIVFSQChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_ivfSQChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newIVFSQChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,63 @@
package indexparamcheck
import (
"fmt"
"strconv"
)
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
type raftIVFPQChecker struct {
ivfBaseChecker
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}
return c.checkPQParams(params)
}
func (c *raftIVFPQChecker) checkPQParams(params map[string]string) error {
dimStr, dimensionExist := params[DIM]
if !dimensionExist {
return fmt.Errorf("dimension not found")
}
dimension, err := strconv.Atoi(dimStr)
if err != nil { // invalid dimension
return fmt.Errorf("invalid dimension: %s", dimStr)
}
// nbits can be set to default: 8
nbitsStr, nbitsExist := params[NBITS]
if nbitsExist {
_, err := strconv.Atoi(nbitsStr)
if err != nil { // invalid nbits
return fmt.Errorf("invalid nbits: %s", nbitsStr)
}
}
mStr, ok := params[IVFM]
if !ok {
return fmt.Errorf("parameter `m` not found")
}
m, err := strconv.Atoi(mStr)
if err != nil { // invalid m
return fmt.Errorf("invalid `m`: %s", mStr)
}
// here is the only difference with IVF_PQ
if m == 0 {
return nil
}
if dimension%m != 0 {
return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m)
}
return nil
}
func newRaftIVFPQChecker() IndexChecker {
return &raftIVFPQChecker{}
}

View File

@ -0,0 +1,220 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsDim := copyParams(validParams)
invalidParamsDim[DIM] = "NAN"
invalidParamsNbits := copyParams(validParams)
invalidParamsNbits[NBITS] = "NAN"
invalidParamsWithoutIVF := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsIVF := copyParams(validParams)
invalidParamsIVF[IVFM] = "NAN"
invalidParamsM := copyParams(validParams)
invalidParamsM[DIM] = strconv.Itoa(65536)
validParamsMzero := copyParams(validParams)
validParamsMzero[IVFM] = "0"
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
{invalidParamsIVF, false},
{invalidParamsM, false},
{validParamsMzero, true},
{p1, true},
{p2, true},
{p3, true},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
{p8, false},
}
c := newRaftIVFPQChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newRaftIVFPQChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -17,6 +17,7 @@
package indexparamcheck
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@ -57,3 +58,7 @@ func CheckStrByValues(params map[string]string, key string, container []string)
return funcutil.SliceContain(container, value)
}
func errOutOfRange(x interface{}, lb interface{}, ub interface{}) error {
return fmt.Errorf("%v out of range: [%v, %v]", x, lb, ub)
}