Add dimension validation.

Signed-off-by: sunby <bingyi.sun@zilliz.com>
This commit is contained in:
sunby 2020-11-26 14:54:03 +08:00 committed by yefu.chen
parent 045fc3d808
commit e8e21a0308
5 changed files with 65 additions and 1 deletions

View File

@ -28,4 +28,5 @@ proxy:
bufSize: 512 bufSize: 512
maxNameLength: 255 maxNameLength: 255
maxFieldNum: 64 maxFieldNum: 64
maxDimension: 32768

View File

@ -480,6 +480,18 @@ func (pt *ParamTable) MaxFieldNum() int64 {
return maxFieldNum return maxFieldNum
} }
func (pt *ParamTable) MaxDimension() int64 {
str, err := pt.Load("proxy.maxDimension")
if err != nil {
panic(err)
}
maxDimension, err := strconv.ParseInt(str, 10, 64)
if err != nil {
panic(err)
}
return maxDimension
}
func (pt *ParamTable) defaultPartitionTag() string { func (pt *ParamTable) defaultPartitionTag() string {
tag, err := pt.Load("common.defaultPartitionTag") tag, err := pt.Load("common.defaultPartitionTag")
if err != nil { if err != nil {

View File

@ -181,6 +181,33 @@ func (cct *CreateCollectionTask) PreExecute() error {
if err := ValidateFieldName(field.Name); err != nil { if err := ValidateFieldName(field.Name); err != nil {
return err return err
} }
if field.DataType == schemapb.DataType_VECTOR_FLOAT || field.DataType == schemapb.DataType_VECTOR_BINARY {
exist := false
var dim int64 = 0
for _, param := range field.TypeParams {
if param.Key == "dim" {
exist = true
tmp, err := strconv.ParseInt(param.Value, 10, 64)
if err != nil {
return err
}
dim = tmp
break
}
}
if !exist {
return errors.New("dimension is not defined in field type params")
}
if field.DataType == schemapb.DataType_VECTOR_FLOAT {
if err := ValidateDimension(dim, false); err != nil {
return err
}
} else {
if err := ValidateDimension(dim, true); err != nil {
return err
}
}
}
} }
return nil return nil

View File

@ -116,3 +116,14 @@ func ValidateFieldName(fieldName string) error {
} }
return nil return nil
} }
func ValidateDimension(dim int64, isBinary bool) error {
if dim <= 0 || dim > Params.MaxDimension() {
return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be in range 1 ~ " +
strconv.FormatInt(Params.MaxDimension(), 10) + ".")
}
if isBinary && dim%8 != 0 {
return errors.New("invalid dimension: " + strconv.FormatInt(dim, 10) + ". should be multiple of 8.")
}
return nil
}

View File

@ -82,3 +82,16 @@ func TestValidateFieldName(t *testing.T) {
assert.NotNil(t, ValidateFieldName(name)) assert.NotNil(t, ValidateFieldName(name))
} }
} }
func TestValidateDimension(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateDimension(1, false))
assert.Nil(t, ValidateDimension(Params.MaxDimension(), false))
assert.Nil(t, ValidateDimension(8, true))
assert.Nil(t, ValidateDimension(Params.MaxDimension(), true))
// invalid dim
assert.NotNil(t, ValidateDimension(-1, false))
assert.NotNil(t, ValidateDimension(Params.MaxDimension()+1, false))
assert.NotNil(t, ValidateDimension(9, true))
}