mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 02:48:45 +08:00
Refactor check logic of index parameters (#23856)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
parent
899702f13c
commit
7be7e6f360
@ -1645,7 +1645,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
||||
zap.String("field", request.FieldName),
|
||||
zap.Any("extra_params", request.ExtraParams))
|
||||
|
||||
log.Debug(rpcReceived(method))
|
||||
log.Info(rpcReceived(method))
|
||||
|
||||
if err := node.sched.ddQueue.Enqueue(cit); err != nil {
|
||||
log.Warn(
|
||||
@ -1661,7 +1661,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
log.Info(
|
||||
rpcEnqueued(method),
|
||||
zap.Uint64("BeginTs", cit.BeginTs()),
|
||||
zap.Uint64("EndTs", cit.EndTs()))
|
||||
@ -1682,7 +1682,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
log.Info(
|
||||
rpcDone(method),
|
||||
zap.Uint64("BeginTs", cit.BeginTs()),
|
||||
zap.Uint64("EndTs", cit.EndTs()))
|
||||
|
@ -242,9 +242,9 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
|
||||
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
|
||||
}
|
||||
|
||||
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
|
||||
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get conf adapter", zap.String(common.IndexTypeKey, indexType))
|
||||
log.Warn("Failed to get index checker", zap.String(common.IndexTypeKey, indexType))
|
||||
return fmt.Errorf("invalid index type: %s", indexType)
|
||||
}
|
||||
|
||||
@ -252,16 +252,14 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
|
||||
return err
|
||||
}
|
||||
|
||||
ok := adapter.CheckValidDataType(field.GetDataType())
|
||||
if !ok {
|
||||
log.Warn("Field data type don't support the index build type", zap.String("fieldDataType", field.GetDataType().String()), zap.String("indexType", indexType))
|
||||
return fmt.Errorf("field data type %s don't support the index build type %s", field.GetDataType().String(), indexType)
|
||||
if err := checker.CheckValidDataType(field.GetDataType()); err != nil {
|
||||
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
|
||||
return err
|
||||
}
|
||||
|
||||
ok = adapter.CheckTrain(indexParams)
|
||||
if !ok {
|
||||
log.Warn("Create index with invalid params", zap.Any("index_params", indexParams))
|
||||
return fmt.Errorf("invalid index params: %v", indexParams)
|
||||
if err := checker.CheckTrain(indexParams); err != nil {
|
||||
log.Info("create index with invalid parameters", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
25
pkg/util/indexparamcheck/base_checker.go
Normal file
25
pkg/util/indexparamcheck/base_checker.go
Normal file
@ -0,0 +1,25 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
type baseChecker struct {
|
||||
}
|
||||
|
||||
func (c *baseChecker) CheckTrain(params map[string]string) error {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return errOutOfRange(DIM, DefaultMinDim, DefaultMaxDim)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckValidDataType check whether the field data type is supported for the index type
|
||||
func (c *baseChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newBaseChecker() IndexChecker {
|
||||
return &baseChecker{}
|
||||
}
|
107
pkg/util/indexparamcheck/base_checker_test.go
Normal file
107
pkg/util/indexparamcheck/base_checker_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_baseChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
}
|
||||
|
||||
c := newBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_baseChecker_CheckValidDataType(t *testing.T) {
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
c := newBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
14
pkg/util/indexparamcheck/bin_flat_checker.go
Normal file
14
pkg/util/indexparamcheck/bin_flat_checker.go
Normal file
@ -0,0 +1,14 @@
|
||||
package indexparamcheck
|
||||
|
||||
type binFlatChecker struct {
|
||||
binaryVectorBaseChecker
|
||||
}
|
||||
|
||||
// CheckTrain checks if a binary flat index can be built with the specific parameters.
|
||||
func (c *binFlatChecker) CheckTrain(params map[string]string) error {
|
||||
return c.binaryVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newBinFlatChecker() IndexChecker {
|
||||
return &binFlatChecker{}
|
||||
}
|
151
pkg/util/indexparamcheck/bin_flat_checker_test.go
Normal file
151
pkg/util/indexparamcheck/bin_flat_checker_test.go
Normal file
@ -0,0 +1,151 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_binFlatChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
{p1, false},
|
||||
{p2, false},
|
||||
{p3, false},
|
||||
{p4, true},
|
||||
{p5, true},
|
||||
{p6, true},
|
||||
{p7, true},
|
||||
{p8, true},
|
||||
}
|
||||
|
||||
c := newBinFlatChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_binFlatChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
c := newBinFlatChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
29
pkg/util/indexparamcheck/bin_ivf_flat_checker.go
Normal file
29
pkg/util/indexparamcheck/bin_ivf_flat_checker.go
Normal file
@ -0,0 +1,29 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type binIVFFlatChecker struct {
|
||||
binaryVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c *binIVFFlatChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
|
||||
return fmt.Errorf("metric type not found or not supported, supported: %v", BinIvfMetrics)
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return errOutOfRange(NLIST, MinNList, MaxNList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newBinIVFFlatChecker() IndexChecker {
|
||||
return &binIVFFlatChecker{}
|
||||
}
|
207
pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go
Normal file
207
pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go
Normal file
@ -0,0 +1,207 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
invalidParams := copyParams(validParams)
|
||||
invalidParams[Metric] = L2
|
||||
|
||||
paramsWithLargeNlist := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MaxNList + 1),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
paramsWithSmallNlist := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MinNList - 1),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
{paramsWithLargeNlist, false},
|
||||
{paramsWithSmallNlist, false},
|
||||
{invalidParams, false},
|
||||
|
||||
{p1, false},
|
||||
{p2, false},
|
||||
{p3, false},
|
||||
|
||||
{p4, true},
|
||||
{p5, true},
|
||||
{p6, true},
|
||||
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newBinIVFFlatChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
c := newBinIVFFlatChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
34
pkg/util/indexparamcheck/binary_vector_base_checker.go
Normal file
34
pkg/util/indexparamcheck/binary_vector_base_checker.go
Normal file
@ -0,0 +1,34 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
type binaryVectorBaseChecker struct {
|
||||
baseChecker
|
||||
}
|
||||
|
||||
func (c *binaryVectorBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !CheckStrByValues(params, Metric, BinIDMapMetrics) {
|
||||
return fmt.Errorf("metric type not found or not supported, supported: %v", BinIDMapMetrics)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *binaryVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
if dType != schemapb.DataType_BinaryVector {
|
||||
return fmt.Errorf("binary vector is only supported")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newBinaryVectorBaseChecker() IndexChecker {
|
||||
return &binaryVectorBaseChecker{}
|
||||
}
|
79
pkg/util/indexparamcheck/binary_vector_base_checker_test.go
Normal file
79
pkg/util/indexparamcheck/binary_vector_base_checker_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
c := newBinaryVectorBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,381 +0,0 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// L2 represents Euclidean distance
|
||||
L2 = "L2"
|
||||
|
||||
// IP represents inner product distance
|
||||
IP = "IP"
|
||||
|
||||
// COSINE represents cosine distance
|
||||
COSINE = "COSINE"
|
||||
|
||||
// HAMMING represents hamming distance
|
||||
HAMMING = "HAMMING"
|
||||
|
||||
// JACCARD represents jaccard distance
|
||||
JACCARD = "JACCARD"
|
||||
|
||||
// TANIMOTO represents tanimoto distance
|
||||
TANIMOTO = "TANIMOTO"
|
||||
|
||||
// SUBSTRUCTURE represents substructure distance
|
||||
SUBSTRUCTURE = "SUBSTRUCTURE"
|
||||
|
||||
// SUPERSTRUCTURE represents superstructure distance
|
||||
SUPERSTRUCTURE = "SUPERSTRUCTURE"
|
||||
|
||||
MinNBits = 1
|
||||
MaxNBits = 16
|
||||
DefaultNBits = 8
|
||||
|
||||
// MinNList is the lower limit of nlist that used in Index IVFxxx
|
||||
MinNList = 1
|
||||
// MaxNList is the upper limit of nlist that used in Index IVFxxx
|
||||
MaxNList = 65536
|
||||
|
||||
// DefaultMinDim is the smallest dimension supported in Milvus
|
||||
DefaultMinDim = 1
|
||||
// DefaultMaxDim is the largest dimension supported in Milvus
|
||||
DefaultMaxDim = 32768
|
||||
|
||||
// If Dim = 32 and raw vector data = 2G, query node need 24G disk space When loading the vectors' disk index
|
||||
// If Dim = 2, and raw vector data = 2G, query node need 240G disk space When loading the vectors' disk index
|
||||
// So DiskAnnMinDim should be greater than or equal to 32 to avoid running out of disk space
|
||||
DiskAnnMinDim = 32
|
||||
|
||||
HNSWMinEfConstruction = 8
|
||||
HNSWMaxEfConstruction = 512
|
||||
HNSWMinM = 4
|
||||
HNSWMaxM = 64
|
||||
|
||||
// DIM is a constant used to represent dimension
|
||||
DIM = "dim"
|
||||
// Metric is a constant used to metric type
|
||||
Metric = "metric_type"
|
||||
// NLIST is a constant used to nlist in Index IVFxxx
|
||||
NLIST = "nlist"
|
||||
NBITS = "nbits"
|
||||
IVFM = "m"
|
||||
|
||||
EFConstruction = "efConstruction"
|
||||
HNSWM = "M"
|
||||
)
|
||||
|
||||
// METRICS is a set of all metrics types supported for float vector.
|
||||
var METRICS = []string{L2, IP, COSINE} // const
|
||||
|
||||
// BinIDMapMetrics is a set of all metric types supported for binary vector.
|
||||
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
|
||||
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
|
||||
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||
|
||||
type ConfAdapter interface {
|
||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||
CheckTrain(map[string]string) bool
|
||||
CheckValidDataType(dType schemapb.DataType) bool
|
||||
}
|
||||
|
||||
// BaseConfAdapter checks if a `FLAT` index can be built.
|
||||
type BaseConfAdapter struct {
|
||||
}
|
||||
|
||||
// CheckTrain check whether the params contains supported metrics types
|
||||
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
return CheckStrByValues(params, Metric, METRICS)
|
||||
}
|
||||
|
||||
// CheckValidDataType check whether the field data type is supported for the index type
|
||||
func (adapter *BaseConfAdapter) CheckValidDataType(dType schemapb.DataType) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func newBaseConfAdapter() *BaseConfAdapter {
|
||||
return &BaseConfAdapter{}
|
||||
}
|
||||
|
||||
// IVFConfAdapter checks if a IVF index can be built.
|
||||
type IVFConfAdapter struct {
|
||||
BaseConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||
func (adapter *IVFConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return false
|
||||
}
|
||||
|
||||
// skip check number of rows
|
||||
|
||||
return adapter.BaseConfAdapter.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newIVFConfAdapter() *IVFConfAdapter {
|
||||
return &IVFConfAdapter{}
|
||||
}
|
||||
|
||||
// IVFPQConfAdapter checks if a IVF_PQ index can be built.
|
||||
type IVFPQConfAdapter struct {
|
||||
IVFConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !adapter.IVFConfAdapter.CheckTrain(params) {
|
||||
return false
|
||||
}
|
||||
|
||||
return adapter.checkPQParams(params)
|
||||
}
|
||||
|
||||
func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool {
|
||||
dimStr, dimensionExist := params[DIM]
|
||||
if !dimensionExist {
|
||||
return false
|
||||
}
|
||||
|
||||
dimension, err := strconv.Atoi(dimStr)
|
||||
if err != nil { // invalid dimension
|
||||
return false
|
||||
}
|
||||
|
||||
// nbits can be set to default: 8
|
||||
nbitsStr, nbitsExist := params[NBITS]
|
||||
if nbitsExist {
|
||||
_, err := strconv.Atoi(nbitsStr)
|
||||
if err != nil { // invalid nbits
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
mStr, ok := params[IVFM]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
m, err := strconv.Atoi(mStr)
|
||||
if err != nil || m == 0 { // invalid m
|
||||
return false
|
||||
}
|
||||
|
||||
return adapter.checkCPUPQParams(dimension, m)
|
||||
}
|
||||
|
||||
func (adapter *IVFPQConfAdapter) checkGPUPQParams(dimension, m, nbits int) bool {
|
||||
/*
|
||||
* Faiss 1.6
|
||||
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
|
||||
* no precomputed codes. Precomputed codes supports any number of dimensions, but will involve memory overheads.
|
||||
*/
|
||||
|
||||
subDim := dimension / m
|
||||
return funcutil.SliceContain(supportSubQuantizer, m) && funcutil.SliceContain(supportDimPerSubQuantizer, subDim) && nbits == 8
|
||||
}
|
||||
|
||||
func (adapter *IVFPQConfAdapter) checkCPUPQParams(dimension, m int) bool {
|
||||
return (dimension % m) == 0
|
||||
}
|
||||
|
||||
func newIVFPQConfAdapter() *IVFPQConfAdapter {
|
||||
return &IVFPQConfAdapter{}
|
||||
}
|
||||
|
||||
// RaftIVFPQConfAdapter checks if a RAFT_IVF_PQ index can be built.
|
||||
type RaftIVFPQConfAdapter struct {
|
||||
IVFConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (adapter *RaftIVFPQConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !adapter.IVFConfAdapter.CheckTrain(params) {
|
||||
return false
|
||||
}
|
||||
|
||||
return adapter.checkPQParams(params)
|
||||
}
|
||||
|
||||
func (adapter *RaftIVFPQConfAdapter) checkPQParams(params map[string]string) bool {
|
||||
dimStr, dimensionExist := params[DIM]
|
||||
if !dimensionExist {
|
||||
return false
|
||||
}
|
||||
|
||||
dimension, err := strconv.Atoi(dimStr)
|
||||
if err != nil { // invalid dimension
|
||||
return false
|
||||
}
|
||||
|
||||
// nbits can be set to default: 8
|
||||
nbitsStr, nbitsExist := params[NBITS]
|
||||
if nbitsExist {
|
||||
_, err := strconv.Atoi(nbitsStr)
|
||||
if err != nil { // invalid nbits
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
mStr, ok := params[IVFM]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
m, err := strconv.Atoi(mStr)
|
||||
if err != nil { // invalid m
|
||||
return false
|
||||
}
|
||||
|
||||
// here is the only difference with IVF_PQ
|
||||
if m == 0 {
|
||||
return true
|
||||
}
|
||||
return dimension%m == 0
|
||||
}
|
||||
|
||||
func newRaftIVFPQConfAdapter() *RaftIVFPQConfAdapter {
|
||||
return &RaftIVFPQConfAdapter{}
|
||||
}
|
||||
|
||||
// IVFSQConfAdapter checks if a IVF_SQ index can be built.
|
||||
type IVFSQConfAdapter struct {
|
||||
IVFConfAdapter
|
||||
}
|
||||
|
||||
func (adapter *IVFSQConfAdapter) checkNBits(params map[string]string) bool {
|
||||
// cgo will set this key to DefaultNBits (8), which is the only value Milvus supports.
|
||||
_, exist := params[NBITS]
|
||||
if exist {
|
||||
// 8 is the only supported nbits.
|
||||
return CheckIntByRange(params, NBITS, DefaultNBits, DefaultNBits)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||
func (adapter *IVFSQConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !adapter.checkNBits(params) {
|
||||
return false
|
||||
}
|
||||
return adapter.IVFConfAdapter.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newIVFSQConfAdapter() *IVFSQConfAdapter {
|
||||
return &IVFSQConfAdapter{}
|
||||
}
|
||||
|
||||
type BinIDMAPConfAdapter struct {
|
||||
BaseConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain checks if a binary flat index can be built with the specific parameters.
|
||||
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
return CheckStrByValues(params, Metric, BinIDMapMetrics)
|
||||
}
|
||||
|
||||
func newBinIDMAPConfAdapter() *BinIDMAPConfAdapter {
|
||||
return &BinIDMAPConfAdapter{}
|
||||
}
|
||||
|
||||
// BinIVFConfAdapter checks if a bin IFV index can be built.
|
||||
type BinIVFConfAdapter struct {
|
||||
BaseConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain checks if a binary ivf index can be built with specific parameters.
|
||||
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
|
||||
return false
|
||||
}
|
||||
|
||||
// skip checking the number of rows
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func newBinIVFConfAdapter() *BinIVFConfAdapter {
|
||||
return &BinIVFConfAdapter{}
|
||||
}
|
||||
|
||||
// HNSWConfAdapter checks if a hnsw index can be built.
|
||||
type HNSWConfAdapter struct {
|
||||
BaseConfAdapter
|
||||
}
|
||||
|
||||
// CheckTrain checks if a hnsw index can be built with specific parameters.
|
||||
func (adapter *HNSWConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
|
||||
return false
|
||||
}
|
||||
|
||||
return adapter.BaseConfAdapter.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newHNSWConfAdapter() *HNSWConfAdapter {
|
||||
return &HNSWConfAdapter{}
|
||||
}
|
||||
|
||||
// DISKANNConfAdapter checks if an diskann index can be built.
|
||||
type DISKANNConfAdapter struct {
|
||||
BaseConfAdapter
|
||||
}
|
||||
|
||||
func (adapter *DISKANNConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
if !CheckIntByRange(params, DIM, DiskAnnMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
return adapter.BaseConfAdapter.CheckTrain(params)
|
||||
}
|
||||
|
||||
// CheckValidDataType check whether the field data type is supported for the index type
|
||||
func (adapter *DISKANNConfAdapter) CheckValidDataType(dType schemapb.DataType) bool {
|
||||
vecDataTypes := []schemapb.DataType{
|
||||
schemapb.DataType_FloatVector,
|
||||
}
|
||||
return funcutil.SliceContain(vecDataTypes, dType)
|
||||
}
|
||||
|
||||
func newDISKANNConfAdapter() *DISKANNConfAdapter {
|
||||
return &DISKANNConfAdapter{}
|
||||
}
|
@ -22,56 +22,53 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
// ConfAdapterMgr manages the conf adapter.
|
||||
type ConfAdapterMgr interface {
|
||||
// GetAdapter gets the conf adapter by the index type.
|
||||
GetAdapter(indexType string) (ConfAdapter, error)
|
||||
type IndexCheckerMgr interface {
|
||||
GetChecker(indexType string) (IndexChecker, error)
|
||||
}
|
||||
|
||||
// ConfAdapterMgrImpl implements ConfAdapter.
|
||||
type ConfAdapterMgrImpl struct {
|
||||
adapters map[IndexType]ConfAdapter
|
||||
// indexCheckerMgrImpl implements IndexChecker.
|
||||
type indexCheckerMgrImpl struct {
|
||||
checkers map[IndexType]IndexChecker
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// GetAdapter gets the conf adapter by the index type.
|
||||
func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) {
|
||||
mgr.once.Do(mgr.registerConfAdapter)
|
||||
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
|
||||
mgr.once.Do(mgr.registerIndexChecker)
|
||||
|
||||
adapter, ok := mgr.adapters[indexType]
|
||||
adapter, ok := mgr.checkers[indexType]
|
||||
if ok {
|
||||
return adapter, nil
|
||||
}
|
||||
return nil, errors.New("Can not find conf adapter: " + indexType)
|
||||
}
|
||||
|
||||
func (mgr *ConfAdapterMgrImpl) registerConfAdapter() {
|
||||
mgr.adapters[IndexRaftIvfFlat] = newIVFConfAdapter()
|
||||
mgr.adapters[IndexRaftIvfPQ] = newRaftIVFPQConfAdapter()
|
||||
mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter()
|
||||
mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter()
|
||||
mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter()
|
||||
mgr.adapters[IndexFaissIvfSQ8] = newIVFSQConfAdapter()
|
||||
mgr.adapters[IndexFaissBinIDMap] = newBinIDMAPConfAdapter()
|
||||
mgr.adapters[IndexFaissBinIvfFlat] = newBinIVFConfAdapter()
|
||||
mgr.adapters[IndexHNSW] = newHNSWConfAdapter()
|
||||
mgr.adapters[IndexDISKANN] = newDISKANNConfAdapter()
|
||||
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
|
||||
mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker()
|
||||
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
|
||||
mgr.checkers[IndexFaissIDMap] = newBaseChecker()
|
||||
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
|
||||
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
|
||||
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
|
||||
mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker()
|
||||
mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker()
|
||||
mgr.checkers[IndexHNSW] = newHnswChecker()
|
||||
mgr.checkers[IndexDISKANN] = newDiskannChecker()
|
||||
}
|
||||
|
||||
func newConfAdapterMgrImpl() *ConfAdapterMgrImpl {
|
||||
return &ConfAdapterMgrImpl{
|
||||
adapters: make(map[IndexType]ConfAdapter),
|
||||
func newIndexCheckerMgr() *indexCheckerMgrImpl {
|
||||
return &indexCheckerMgrImpl{
|
||||
checkers: make(map[IndexType]IndexChecker),
|
||||
}
|
||||
}
|
||||
|
||||
var confAdapterMgr ConfAdapterMgr
|
||||
var indexCheckerMgr IndexCheckerMgr
|
||||
|
||||
var getConfAdapterMgrOnce sync.Once
|
||||
var getIndexCheckerMgrOnce sync.Once
|
||||
|
||||
// GetConfAdapterMgrInstance gets the instance of ConfAdapterMgr.
|
||||
func GetConfAdapterMgrInstance() ConfAdapterMgr {
|
||||
getConfAdapterMgrOnce.Do(func() {
|
||||
confAdapterMgr = newConfAdapterMgrImpl()
|
||||
// GetIndexCheckerMgrInstance gets the instance of IndexCheckerMgr.
|
||||
func GetIndexCheckerMgrInstance() IndexCheckerMgr {
|
||||
getIndexCheckerMgrOnce.Do(func() {
|
||||
indexCheckerMgr = newIndexCheckerMgr()
|
||||
})
|
||||
return confAdapterMgr
|
||||
return indexCheckerMgr
|
||||
}
|
||||
|
@ -19,122 +19,122 @@ import (
|
||||
)
|
||||
|
||||
func Test_GetConfAdapterMgrInstance(t *testing.T) {
|
||||
adapterMgr := GetConfAdapterMgrInstance()
|
||||
adapterMgr := GetIndexCheckerMgrInstance()
|
||||
|
||||
var adapter ConfAdapter
|
||||
var adapter IndexChecker
|
||||
var err error
|
||||
var ok bool
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter("invalid")
|
||||
adapter, err = adapterMgr.GetChecker("invalid")
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, adapter)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIDMap)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BaseConfAdapter)
|
||||
_, ok = adapter.(*baseChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFConfAdapter)
|
||||
_, ok = adapter.(*ivfBaseChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfPQ)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFPQConfAdapter)
|
||||
_, ok = adapter.(*ivfPQChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfSQ8)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFSQConfAdapter)
|
||||
_, ok = adapter.(*ivfSQChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIDMap)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BinIDMAPConfAdapter)
|
||||
_, ok = adapter.(*binFlatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BinIVFConfAdapter)
|
||||
_, ok = adapter.(*binIVFFlatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexHNSW)
|
||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*HNSWConfAdapter)
|
||||
_, ok = adapter.(*hnswChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
|
||||
adapterMgr := newConfAdapterMgrImpl()
|
||||
adapterMgr := newIndexCheckerMgr()
|
||||
|
||||
var adapter ConfAdapter
|
||||
var adapter IndexChecker
|
||||
var err error
|
||||
var ok bool
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter("invalid")
|
||||
adapter, err = adapterMgr.GetChecker("invalid")
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, nil, adapter)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIDMap)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BaseConfAdapter)
|
||||
_, ok = adapter.(*baseChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFConfAdapter)
|
||||
_, ok = adapter.(*ivfBaseChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfPQ)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFPQConfAdapter)
|
||||
_, ok = adapter.(*ivfPQChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissIvfSQ8)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*IVFSQConfAdapter)
|
||||
_, ok = adapter.(*ivfSQChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIDMap)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BinIDMAPConfAdapter)
|
||||
_, ok = adapter.(*binFlatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexFaissBinIvfFlat)
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*BinIVFConfAdapter)
|
||||
_, ok = adapter.(*binIVFFlatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetAdapter(IndexHNSW)
|
||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*HNSWConfAdapter)
|
||||
_, ok = adapter.(*hnswChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) {
|
||||
num := 4
|
||||
mgr := newConfAdapterMgrImpl()
|
||||
mgr := newIndexCheckerMgr()
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < num; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
adapter, err := mgr.GetAdapter(IndexHNSW)
|
||||
adapter, err := mgr.GetChecker(IndexHNSW)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, adapter)
|
||||
}()
|
||||
|
@ -1,410 +0,0 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TODO: add more test cases which `ConfAdapter.CheckTrain` return false,
|
||||
// for example, maybe we can copy test cases from regression test later.
|
||||
|
||||
func invalidIVFParamsMin() map[string]string {
|
||||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MinNList - 1),
|
||||
Metric: L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
||||
func invalidIVFParamsMax() map[string]string {
|
||||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MaxNList + 1),
|
||||
Metric: L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
||||
func copyParams(original map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
for key, value := range original {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestBaseConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
}
|
||||
|
||||
adapter := newBaseConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("BaseConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IVFConfAdapter checks if an ivf index can be built.
|
||||
func TestIVFConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
}
|
||||
|
||||
adapter := newIVFConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("IVFConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
invalidParamsDim[DIM] = "NAN"
|
||||
|
||||
invalidParamsNbits := copyParams(validParams)
|
||||
invalidParamsNbits[NBITS] = "NAN"
|
||||
|
||||
invalidParamsWithoutIVF := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
invalidParamsIVF[IVFM] = "NAN"
|
||||
|
||||
invalidParamsM := copyParams(validParams)
|
||||
invalidParamsM[DIM] = strconv.Itoa(65536)
|
||||
|
||||
invalidParamsMzero := copyParams(validParams)
|
||||
invalidParamsMzero[IVFM] = "0"
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithoutNbits, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{validParamsWithoutDim, false},
|
||||
{invalidParamsDim, false},
|
||||
{invalidParamsNbits, false},
|
||||
{invalidParamsWithoutIVF, false},
|
||||
{invalidParamsIVF, false},
|
||||
{invalidParamsM, false},
|
||||
{invalidParamsMzero, false},
|
||||
}
|
||||
|
||||
adapter := newIVFPQConfAdapter()
|
||||
for i, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Log("i:", i, "params", test.params)
|
||||
t.Errorf("IVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaftIVFPQConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
invalidParamsDim[DIM] = "NAN"
|
||||
|
||||
invalidParamsNbits := copyParams(validParams)
|
||||
invalidParamsNbits[NBITS] = "NAN"
|
||||
|
||||
invalidParamsWithoutIVF := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
invalidParamsIVF[IVFM] = "NAN"
|
||||
|
||||
invalidParamsM := copyParams(validParams)
|
||||
invalidParamsM[DIM] = strconv.Itoa(65536)
|
||||
|
||||
validParamsMzero := copyParams(validParams)
|
||||
validParamsMzero[IVFM] = "0"
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithoutNbits, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{validParamsWithoutDim, false},
|
||||
{invalidParamsDim, false},
|
||||
{invalidParamsNbits, false},
|
||||
{invalidParamsWithoutIVF, false},
|
||||
{invalidParamsIVF, false},
|
||||
{invalidParamsM, false},
|
||||
{validParamsMzero, true},
|
||||
}
|
||||
|
||||
adapter := newRaftIVFPQConfAdapter()
|
||||
for i, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Log("i:", i, "params", test.params)
|
||||
t.Errorf("RaftIVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIVFSQConfAdapter_CheckTrain(t *testing.T) {
|
||||
getValidParams := func(withNBits bool) map[string]string {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
if withNBits {
|
||||
validParams[NBITS] = strconv.Itoa(DefaultNBits)
|
||||
}
|
||||
return validParams
|
||||
}
|
||||
validParams := getValidParams(false)
|
||||
validParamsWithNBits := getValidParams(true)
|
||||
paramsWithInvalidNBits := getValidParams(false)
|
||||
paramsWithInvalidNBits[NBITS] = strconv.Itoa(DefaultNBits + 1)
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithNBits, true},
|
||||
{paramsWithInvalidNBits, false},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
}
|
||||
|
||||
adapter := newIVFSQConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("IVFSQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BinIDMAPConfAdapter checks if a bin id map index can be built.
|
||||
func TestBinIDMAPConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: JACCARD,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
}
|
||||
|
||||
adapter := newBinIDMAPConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("BinIDMAPConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BinIVFConfAdapter checks if a bin ivf index can be built.
|
||||
func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
invalidParams := copyParams(validParams)
|
||||
invalidParams[Metric] = L2
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{invalidParams, false},
|
||||
}
|
||||
|
||||
adapter := newBinIVFConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("BinIVFConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HNSWConfAdapter checks if a hnsw index can be built.
|
||||
func TestHNSWConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidEfParamsMin := copyParams(validParams)
|
||||
invalidEfParamsMin[EFConstruction] = strconv.Itoa(HNSWMinEfConstruction - 1)
|
||||
|
||||
invalidEfParamsMax := copyParams(validParams)
|
||||
invalidEfParamsMax[EFConstruction] = strconv.Itoa(HNSWMaxEfConstruction + 1)
|
||||
|
||||
invalidMParamsMin := copyParams(validParams)
|
||||
invalidMParamsMin[HNSWM] = strconv.Itoa(HNSWMinM - 1)
|
||||
|
||||
invalidMParamsMax := copyParams(validParams)
|
||||
invalidMParamsMax[HNSWM] = strconv.Itoa(HNSWMaxM + 1)
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{invalidEfParamsMin, false},
|
||||
{invalidEfParamsMax, false},
|
||||
{invalidMParamsMin, false},
|
||||
{invalidMParamsMax, false},
|
||||
}
|
||||
|
||||
adapter := newHNSWConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("HNSWConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DISKANNConfAdapter checks if an diskann index can be built
|
||||
func TestDiskAnnConfAdapter_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
validParamsBigDim := copyParams(validParams)
|
||||
validParamsBigDim[DIM] = strconv.Itoa(2048)
|
||||
|
||||
invalidParamsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsSmallDim := copyParams(validParams)
|
||||
invalidParamsSmallDim[DIM] = strconv.Itoa(15)
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsBigDim, true},
|
||||
{invalidParamsWithoutDim, false},
|
||||
{invalidParamsSmallDim, false},
|
||||
}
|
||||
|
||||
adapter := newDISKANNConfAdapter()
|
||||
for _, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
t.Errorf("DiskAnnConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
}
|
72
pkg/util/indexparamcheck/constraints.go
Normal file
72
pkg/util/indexparamcheck/constraints.go
Normal file
@ -0,0 +1,72 @@
|
||||
package indexparamcheck
|
||||
|
||||
const (
|
||||
// L2 represents Euclidean distance
|
||||
L2 = "L2"
|
||||
|
||||
// IP represents inner product distance
|
||||
IP = "IP"
|
||||
|
||||
// COSINE represents cosine distance
|
||||
COSINE = "COSINE"
|
||||
|
||||
// HAMMING represents hamming distance
|
||||
HAMMING = "HAMMING"
|
||||
|
||||
// JACCARD represents jaccard distance
|
||||
JACCARD = "JACCARD"
|
||||
|
||||
// TANIMOTO represents tanimoto distance
|
||||
TANIMOTO = "TANIMOTO"
|
||||
|
||||
// SUBSTRUCTURE represents substructure distance
|
||||
SUBSTRUCTURE = "SUBSTRUCTURE"
|
||||
|
||||
// SUPERSTRUCTURE represents superstructure distance
|
||||
SUPERSTRUCTURE = "SUPERSTRUCTURE"
|
||||
|
||||
MinNBits = 1
|
||||
MaxNBits = 16
|
||||
DefaultNBits = 8
|
||||
|
||||
// MinNList is the lower limit of nlist that used in Index IVFxxx
|
||||
MinNList = 1
|
||||
// MaxNList is the upper limit of nlist that used in Index IVFxxx
|
||||
MaxNList = 65536
|
||||
|
||||
// DefaultMinDim is the smallest dimension supported in Milvus
|
||||
DefaultMinDim = 1
|
||||
// DefaultMaxDim is the largest dimension supported in Milvus
|
||||
DefaultMaxDim = 32768
|
||||
|
||||
// If Dim = 32 and raw vector data = 2G, query node need 24G disk space When loading the vectors' disk index
|
||||
// If Dim = 2, and raw vector data = 2G, query node need 240G disk space When loading the vectors' disk index
|
||||
// So DiskAnnMinDim should be greater than or equal to 32 to avoid running out of disk space
|
||||
DiskAnnMinDim = 32
|
||||
|
||||
HNSWMinEfConstruction = 8
|
||||
HNSWMaxEfConstruction = 512
|
||||
HNSWMinM = 4
|
||||
HNSWMaxM = 64
|
||||
|
||||
// DIM is a constant used to represent dimension
|
||||
DIM = "dim"
|
||||
// Metric is a constant used to metric type
|
||||
Metric = "metric_type"
|
||||
// NLIST is a constant used to nlist in Index IVFxxx
|
||||
NLIST = "nlist"
|
||||
NBITS = "nbits"
|
||||
IVFM = "m"
|
||||
|
||||
EFConstruction = "efConstruction"
|
||||
HNSWM = "M"
|
||||
)
|
||||
|
||||
// METRICS is a set of all metrics types supported for float vector.
|
||||
var METRICS = []string{L2, IP, COSINE} // const
|
||||
|
||||
// BinIDMapMetrics is a set of all metric types supported for binary vector.
|
||||
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
|
||||
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
|
||||
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
17
pkg/util/indexparamcheck/diskann_checker.go
Normal file
17
pkg/util/indexparamcheck/diskann_checker.go
Normal file
@ -0,0 +1,17 @@
|
||||
package indexparamcheck
|
||||
|
||||
// diskannChecker checks if an diskann index can be built.
|
||||
type diskannChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c *diskannChecker) CheckTrain(params map[string]string) error {
|
||||
if !CheckIntByRange(params, DIM, DiskAnnMinDim, DefaultMaxDim) {
|
||||
return errOutOfRange(DIM, DiskAnnMinDim, DefaultMaxDim)
|
||||
}
|
||||
return c.floatVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newDiskannChecker() IndexChecker {
|
||||
return &diskannChecker{}
|
||||
}
|
159
pkg/util/indexparamcheck/diskann_checker_test.go
Normal file
159
pkg/util/indexparamcheck/diskann_checker_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_diskannChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
validParamsBigDim := copyParams(validParams)
|
||||
validParamsBigDim[DIM] = strconv.Itoa(2048)
|
||||
|
||||
invalidParamsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsSmallDim := copyParams(validParams)
|
||||
invalidParamsSmallDim[DIM] = strconv.Itoa(15)
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsBigDim, true},
|
||||
{invalidParamsWithoutDim, false},
|
||||
{invalidParamsSmallDim, false},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newDiskannChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_diskannChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newDiskannChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
34
pkg/util/indexparamcheck/float_vector_base_checker.go
Normal file
34
pkg/util/indexparamcheck/float_vector_base_checker.go
Normal file
@ -0,0 +1,34 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
type floatVectorBaseChecker struct {
|
||||
baseChecker
|
||||
}
|
||||
|
||||
func (c *floatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !CheckStrByValues(params, Metric, METRICS) {
|
||||
return fmt.Errorf("metric type not found or not supported, supported: %v", METRICS)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
if dType != schemapb.DataType_FloatVector {
|
||||
return fmt.Errorf("float vector is only supported")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newFloatVectorBaseChecker() IndexChecker {
|
||||
return &floatVectorBaseChecker{}
|
||||
}
|
79
pkg/util/indexparamcheck/float_vector_base_checker_test.go
Normal file
79
pkg/util/indexparamcheck/float_vector_base_checker_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newFloatVectorBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
21
pkg/util/indexparamcheck/hnsw_checker.go
Normal file
21
pkg/util/indexparamcheck/hnsw_checker.go
Normal file
@ -0,0 +1,21 @@
|
||||
package indexparamcheck
|
||||
|
||||
type hnswChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c *hnswChecker) CheckTrain(params map[string]string) error {
|
||||
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
|
||||
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
|
||||
return errOutOfRange(HNSWM, HNSWMinM, HNSWMaxM)
|
||||
}
|
||||
|
||||
return c.floatVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newHnswChecker() IndexChecker {
|
||||
return &hnswChecker{}
|
||||
}
|
182
pkg/util/indexparamcheck/hnsw_checker_test.go
Normal file
182
pkg/util/indexparamcheck/hnsw_checker_test.go
Normal file
@ -0,0 +1,182 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_hnswChecker_CheckTrain(t *testing.T) {
|
||||
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidEfParamsMin := copyParams(validParams)
|
||||
invalidEfParamsMin[EFConstruction] = strconv.Itoa(HNSWMinEfConstruction - 1)
|
||||
|
||||
invalidEfParamsMax := copyParams(validParams)
|
||||
invalidEfParamsMax[EFConstruction] = strconv.Itoa(HNSWMaxEfConstruction + 1)
|
||||
|
||||
invalidMParamsMin := copyParams(validParams)
|
||||
invalidMParamsMin[HNSWM] = strconv.Itoa(HNSWMinM - 1)
|
||||
|
||||
invalidMParamsMax := copyParams(validParams)
|
||||
invalidMParamsMax[HNSWM] = strconv.Itoa(HNSWMaxM + 1)
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{invalidEfParamsMin, false},
|
||||
{invalidEfParamsMax, false},
|
||||
{invalidMParamsMin, false},
|
||||
{invalidMParamsMax, false},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newHnswChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_hnswChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newHnswChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
26
pkg/util/indexparamcheck/index_checker.go
Normal file
26
pkg/util/indexparamcheck/index_checker.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
type IndexChecker interface {
|
||||
CheckTrain(map[string]string) error
|
||||
CheckValidDataType(dType schemapb.DataType) error
|
||||
}
|
45
pkg/util/indexparamcheck/index_checker_test.go
Normal file
45
pkg/util/indexparamcheck/index_checker_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// TODO: add more test cases which `IndexChecker.CheckTrain` return false,
|
||||
// for example, maybe we can copy test cases from regression test later.
|
||||
|
||||
func invalidIVFParamsMin() map[string]string {
|
||||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MinNList - 1),
|
||||
Metric: L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
||||
func invalidIVFParamsMax() map[string]string {
|
||||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MaxNList + 1),
|
||||
Metric: L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
||||
func copyParams(original map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
for key, value := range original {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
}
|
19
pkg/util/indexparamcheck/ivf_base_checker.go
Normal file
19
pkg/util/indexparamcheck/ivf_base_checker.go
Normal file
@ -0,0 +1,19 @@
|
||||
package indexparamcheck
|
||||
|
||||
type ivfBaseChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func (c *ivfBaseChecker) CheckTrain(params map[string]string) error {
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return errOutOfRange(NLIST, MinNList, MaxNList)
|
||||
}
|
||||
|
||||
// skip check number of rows
|
||||
|
||||
return c.floatVectorBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newIVFBaseChecker() IndexChecker {
|
||||
return &ivfBaseChecker{}
|
||||
}
|
157
pkg/util/indexparamcheck/ivf_base_checker_test.go
Normal file
157
pkg/util/indexparamcheck/ivf_base_checker_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newIVFBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) {
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newIVFBaseChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
63
pkg/util/indexparamcheck/ivf_pq_checker.go
Normal file
63
pkg/util/indexparamcheck/ivf_pq_checker.go
Normal file
@ -0,0 +1,63 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ivfPQChecker checks if a IVF_PQ index can be built.
|
||||
type ivfPQChecker struct {
|
||||
ivfBaseChecker
|
||||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (c *ivfPQChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.checkPQParams(params)
|
||||
}
|
||||
|
||||
func (c *ivfPQChecker) checkPQParams(params map[string]string) error {
|
||||
dimStr, dimensionExist := params[DIM]
|
||||
if !dimensionExist {
|
||||
return fmt.Errorf("dimension not found")
|
||||
}
|
||||
|
||||
dimension, err := strconv.Atoi(dimStr)
|
||||
if err != nil { // invalid dimension
|
||||
return fmt.Errorf("invalid dimension: %s", dimStr)
|
||||
}
|
||||
|
||||
// nbits can be set to default: 8
|
||||
nbitsStr, nbitsExist := params[NBITS]
|
||||
if nbitsExist {
|
||||
_, err := strconv.Atoi(nbitsStr)
|
||||
if err != nil { // invalid nbits
|
||||
return fmt.Errorf("invalid nbits: %s", nbitsStr)
|
||||
}
|
||||
}
|
||||
|
||||
mStr, ok := params[IVFM]
|
||||
if !ok {
|
||||
return fmt.Errorf("parameter `m` not found")
|
||||
}
|
||||
m, err := strconv.Atoi(mStr)
|
||||
if err != nil || m == 0 { // invalid m
|
||||
return fmt.Errorf("invalid `m`: %s", mStr)
|
||||
}
|
||||
|
||||
return c.checkCPUPQParams(dimension, m)
|
||||
}
|
||||
|
||||
func (c *ivfPQChecker) checkCPUPQParams(dimension, m int) error {
|
||||
if (dimension % m) != 0 {
|
||||
return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newIVFPQChecker() IndexChecker {
|
||||
return &ivfPQChecker{}
|
||||
}
|
229
pkg/util/indexparamcheck/ivf_pq_checker_test.go
Normal file
229
pkg/util/indexparamcheck/ivf_pq_checker_test.go
Normal file
@ -0,0 +1,229 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
paramsNotMultiplier := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(5),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
invalidParamsDim[DIM] = "NAN"
|
||||
|
||||
invalidParamsNbits := copyParams(validParams)
|
||||
invalidParamsNbits[NBITS] = "NAN"
|
||||
|
||||
invalidParamsWithoutIVF := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
invalidParamsIVF[IVFM] = "NAN"
|
||||
|
||||
invalidParamsM := copyParams(validParams)
|
||||
invalidParamsM[DIM] = strconv.Itoa(65536)
|
||||
|
||||
invalidParamsMzero := copyParams(validParams)
|
||||
invalidParamsMzero[IVFM] = "0"
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsNotMultiplier, false},
|
||||
{validParamsWithoutNbits, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{validParamsWithoutDim, false},
|
||||
{invalidParamsDim, false},
|
||||
{invalidParamsNbits, false},
|
||||
{invalidParamsWithoutIVF, false},
|
||||
{invalidParamsIVF, false},
|
||||
{invalidParamsM, false},
|
||||
{invalidParamsMzero, false},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newIVFPQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ivfPQChecker_CheckValidDataType(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newIVFPQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
34
pkg/util/indexparamcheck/ivf_sq_checker.go
Normal file
34
pkg/util/indexparamcheck/ivf_sq_checker.go
Normal file
@ -0,0 +1,34 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ivfSQChecker checks if a IVF_SQ index can be built.
|
||||
type ivfSQChecker struct {
|
||||
ivfBaseChecker
|
||||
}
|
||||
|
||||
func (c *ivfSQChecker) checkNBits(params map[string]string) error {
|
||||
// cgo will set this key to DefaultNBits (8), which is the only value Milvus supports.
|
||||
_, exist := params[NBITS]
|
||||
if exist {
|
||||
// 8 is the only supported nbits.
|
||||
if !CheckIntByRange(params, NBITS, DefaultNBits, DefaultNBits) {
|
||||
return fmt.Errorf("nbits can be only set to 8 for IVF_SQ")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||
func (c *ivfSQChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.checkNBits(params); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.ivfBaseChecker.CheckTrain(params)
|
||||
}
|
||||
|
||||
func newIVFSQChecker() IndexChecker {
|
||||
return &ivfSQChecker{}
|
||||
}
|
178
pkg/util/indexparamcheck/ivf_sq_checker_test.go
Normal file
178
pkg/util/indexparamcheck/ivf_sq_checker_test.go
Normal file
@ -0,0 +1,178 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
||||
getValidParams := func(withNBits bool) map[string]string {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
if withNBits {
|
||||
validParams[NBITS] = strconv.Itoa(DefaultNBits)
|
||||
}
|
||||
return validParams
|
||||
}
|
||||
validParams := getValidParams(false)
|
||||
validParamsWithNBits := getValidParams(true)
|
||||
paramsWithInvalidNBits := getValidParams(false)
|
||||
paramsWithInvalidNBits[NBITS] = strconv.Itoa(DefaultNBits + 1)
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithNBits, true},
|
||||
{paramsWithInvalidNBits, false},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newIVFSQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ivfSQChecker_CheckValidDataType(t *testing.T) {
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newIVFSQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
63
pkg/util/indexparamcheck/raft_ivf_pq_checker.go
Normal file
63
pkg/util/indexparamcheck/raft_ivf_pq_checker.go
Normal file
@ -0,0 +1,63 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
|
||||
type raftIVFPQChecker struct {
|
||||
ivfBaseChecker
|
||||
}
|
||||
|
||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||
func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
|
||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.checkPQParams(params)
|
||||
}
|
||||
|
||||
func (c *raftIVFPQChecker) checkPQParams(params map[string]string) error {
|
||||
dimStr, dimensionExist := params[DIM]
|
||||
if !dimensionExist {
|
||||
return fmt.Errorf("dimension not found")
|
||||
}
|
||||
|
||||
dimension, err := strconv.Atoi(dimStr)
|
||||
if err != nil { // invalid dimension
|
||||
return fmt.Errorf("invalid dimension: %s", dimStr)
|
||||
}
|
||||
|
||||
// nbits can be set to default: 8
|
||||
nbitsStr, nbitsExist := params[NBITS]
|
||||
if nbitsExist {
|
||||
_, err := strconv.Atoi(nbitsStr)
|
||||
if err != nil { // invalid nbits
|
||||
return fmt.Errorf("invalid nbits: %s", nbitsStr)
|
||||
}
|
||||
}
|
||||
|
||||
mStr, ok := params[IVFM]
|
||||
if !ok {
|
||||
return fmt.Errorf("parameter `m` not found")
|
||||
}
|
||||
m, err := strconv.Atoi(mStr)
|
||||
if err != nil { // invalid m
|
||||
return fmt.Errorf("invalid `m`: %s", mStr)
|
||||
}
|
||||
|
||||
// here is the only difference with IVF_PQ
|
||||
if m == 0 {
|
||||
return nil
|
||||
}
|
||||
if dimension%m != 0 {
|
||||
return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newRaftIVFPQChecker() IndexChecker {
|
||||
return &raftIVFPQChecker{}
|
||||
}
|
220
pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go
Normal file
220
pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go
Normal file
@ -0,0 +1,220 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
||||
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
invalidParamsDim[DIM] = "NAN"
|
||||
|
||||
invalidParamsNbits := copyParams(validParams)
|
||||
invalidParamsNbits[NBITS] = "NAN"
|
||||
|
||||
invalidParamsWithoutIVF := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
invalidParamsIVF[IVFM] = "NAN"
|
||||
|
||||
invalidParamsM := copyParams(validParams)
|
||||
invalidParamsM[DIM] = strconv.Itoa(65536)
|
||||
|
||||
validParamsMzero := copyParams(validParams)
|
||||
validParamsMzero[IVFM] = "0"
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{validParamsWithoutNbits, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{validParamsWithoutDim, false},
|
||||
{invalidParamsDim, false},
|
||||
{invalidParamsNbits, false},
|
||||
{invalidParamsWithoutIVF, false},
|
||||
{invalidParamsIVF, false},
|
||||
{invalidParamsM, false},
|
||||
{validParamsMzero, true},
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newRaftIVFPQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) {
|
||||
cases := []struct {
|
||||
dType schemapb.DataType
|
||||
errIsNil bool
|
||||
}{
|
||||
{
|
||||
dType: schemapb.DataType_Bool,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int8,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int16,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int32,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Int64,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Float,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Double,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_String,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_VarChar,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_Array,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_JSON,
|
||||
errIsNil: false,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_FloatVector,
|
||||
errIsNil: true,
|
||||
},
|
||||
{
|
||||
dType: schemapb.DataType_BinaryVector,
|
||||
errIsNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
c := newRaftIVFPQChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckValidDataType(test.dType)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
@ -17,6 +17,7 @@
|
||||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
@ -57,3 +58,7 @@ func CheckStrByValues(params map[string]string, key string, container []string)
|
||||
|
||||
return funcutil.SliceContain(container, value)
|
||||
}
|
||||
|
||||
func errOutOfRange(x interface{}, lb interface{}, ub interface{}) error {
|
||||
return fmt.Errorf("%v out of range: [%v, %v]", x, lb, ub)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user