Import util functions (#16237)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot 2022-03-30 16:25:30 +08:00 committed by GitHub
parent 801eeffbcc
commit ffa06c77b6
6 changed files with 2211 additions and 0 deletions

View File

@@ -0,0 +1,401 @@
package importutil
import (
"bufio"
"context"
"errors"
"os"
"path"
"strconv"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
const (
JSONFileExt = ".json"
NumpyFileExt = ".npy"
)
type ImportWrapper struct {
ctx context.Context // for canceling parse process
cancel context.CancelFunc // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
shardNum int32 // sharding number of the collection
segmentSize int32 // maximum size of a segment in MB
rowIDAllocator *allocator.IDAllocator // autoid allocator
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush a segment
}
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int32,
idAlloc *allocator.IDAllocator, flushFunc func(fields map[string]storage.FieldData) error) *ImportWrapper {
if collectionSchema == nil {
log.Error("import error: collection schema is nil")
return nil
}
// ignore the RowID field and Timestamp field
realSchema := &schemapb.CollectionSchema{
Name: collectionSchema.GetName(),
Description: collectionSchema.GetDescription(),
AutoID: collectionSchema.GetAutoID(),
Fields: make([]*schemapb.FieldSchema, 0),
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetName() == common.RowIDFieldName || schema.GetName() == common.TimeStampFieldName {
continue
}
realSchema.Fields = append(realSchema.Fields, schema)
}
ctx, cancel := context.WithCancel(ctx)
wrapper := &ImportWrapper{
ctx: ctx,
cancel: cancel,
collectionSchema: realSchema,
shardNum: shardNum,
segmentSize: segmentSize,
rowIDAllocator: idAlloc,
callFlushFunc: flushFunc,
}
return wrapper
}
// this method can be used to cancel the parse process
func (p *ImportWrapper) Cancel() error {
p.cancel()
return nil
}
func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldData, msg string, files []string) {
stats := make([]zapcore.Field, 0)
for k, v := range fieldsData {
stats = append(stats, zap.Int(k, v.RowNum()))
}
for i := 0; i < len(files); i++ {
stats = append(stats, zap.String("file", files[i]))
}
log.Debug(msg, stats...)
}
// Import is the entry of the import process
// filePaths and rowBased are from the ImportTask
// if onlyValidate is true, this process only does validation; no data is generated and callFlushFunc is not called
func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate bool) error {
if rowBased {
// parse and consume row-based files
// for row-based files, the JSONRowConsumer generates auto-ids for the primary key and splits rows into
// segments according to the shard number, so callFlushFunc is called inside the JSONRowConsumer
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
log.Debug("imprort wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err := func() error {
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()
reader := bufio.NewReader(file)
parser := NewJSONParser(p.ctx, p.collectionSchema)
var consumer *JSONRowConsumer
if !onlyValidate {
flushFunc := func(fields map[string]storage.FieldData) error {
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
return p.callFlushFunc(fields)
}
consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc)
}
validator := NewJSONRowValidator(p.collectionSchema, consumer)
err = parser.ParseRows(reader, validator)
if err != nil {
return err
}
return nil
}()
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
}
}
} else {
// parse and consume column-based files
// for column-based files, the XXXColumnConsumer only outputs map[string]storage.FieldData
// after all columns are parsed/consumed, we combine them into a single map[string]storage.FieldData
// and use splitFieldsData() to split the fields data into segments according to the shard number
fieldsData := initSegmentData(p.collectionSchema)
rowCount := 0
// function to combine column data into fieldsData
combineFunc := func(fields map[string]storage.FieldData) error {
if len(fields) == 0 {
return nil
}
fieldNames := make([]string, 0)
for k, v := range fields {
data, ok := fieldsData[k]
if ok && data.RowNum() > 0 {
return errors.New("imprort error: the field " + k + " is duplicated")
}
fieldsData[k] = v
fieldNames = append(fieldNames, k)
if rowCount == 0 {
rowCount = v.RowNum()
} else if rowCount != v.RowNum() {
return errors.New("imprort error: the field " + k + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
}
}
log.Debug("imprort wrapper: ", zap.Any("fieldNames", fieldNames), zap.Int("rowCount", rowCount))
return nil
}
// parse/validate/consume data
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
log.Debug("imprort wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
err := func() error {
file, err := os.Open(filePath)
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
defer file.Close()
reader := bufio.NewReader(file)
parser := NewJSONParser(p.ctx, p.collectionSchema)
var consumer *JSONColumnConsumer
if !onlyValidate {
consumer = NewJSONColumnConsumer(p.collectionSchema, combineFunc)
}
validator := NewJSONColumnValidator(p.collectionSchema, consumer)
err = parser.ParseColumns(reader, validator)
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
return nil
}()
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
} else if fileType == NumpyFileExt {
// numpy file import is not implemented yet
}
}
// split fields data into segments
err := p.splitFieldsData(fieldsData, filePaths)
if err != nil {
log.Error("imprort error: " + err.Error())
return err
}
}
return nil
}
func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error {
switch schema.DataType {
case schemapb.DataType_Bool:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BoolFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(bool))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Float:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Double:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.DoubleFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(float64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int8FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int8))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int16FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int16))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int32FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int32))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.Int64FieldData)
arr.Data = append(arr.Data, src.GetRow(n).(int64))
arr.NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.BinaryVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]byte)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.FloatVectorFieldData)
arr.Data = append(arr.Data, src.GetRow(n).([]float32)...)
arr.NumRows[0]++
return nil
}
case schemapb.DataType_String:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.StringFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(string))
arr.NumRows[0]++
return nil
}
default:
return nil
}
}
func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, files []string) error {
if len(fieldsData) == 0 {
return errors.New("imprort error: fields data is empty")
}
var primaryKey *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
primaryKey = schema
} else {
_, ok := fieldsData[schema.GetName()]
if !ok {
return errors.New("imprort error: field " + schema.GetName() + " not provided")
}
}
}
if primaryKey == nil {
return errors.New("imprort error: primary key field is not found")
}
rowCount := 0
for _, v := range fieldsData {
rowCount = v.RowNum()
break
}
primaryData, ok := fieldsData[primaryKey.GetName()]
if !ok || primaryData.RowNum() <= 0 {
// generate auto id for primary key
if primaryKey.GetAutoID() {
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
rowIDBegin, rowIDEnd, _ = p.rowIDAllocator.Alloc(uint32(rowCount))
primaryDataArr := &storage.Int64FieldData{
NumRows: []int64{int64(rowCount)},
Data:    make([]int64, 0, rowCount),
}
for i := rowIDBegin; i < rowIDEnd; i++ {
primaryDataArr.Data = append(primaryDataArr.Data, i)
}
primaryData = primaryDataArr
fieldsData[primaryKey.GetName()] = primaryData
}
}
if primaryData == nil || primaryData.RowNum() <= 0 {
return errors.New("import error: primary key " + primaryKey.GetName() + " not provided")
}
// prepare segments
segmentsData := make([]map[string]storage.FieldData, 0, p.shardNum)
for i := 0; i < int(p.shardNum); i++ {
segmentData := initSegmentData(p.collectionSchema)
if segmentData == nil {
return errors.New("import error: failed to initialize in-memory segment data")
}
segmentsData = append(segmentsData, segmentData)
}
// prepare append functions
appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error)
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema := p.collectionSchema.Fields[i]
appendFunc := p.appendFunc(schema)
if appendFunc == nil {
return errors.New("imprort error: unsupported field data type")
}
appendFunctions[schema.GetName()] = appendFunc
}
// split data into segments
for i := 0; i < rowCount; i++ {
id := primaryData.GetRow(i).(int64)
// hash to a shard number
hash, _ := typeutil.Hash32Int64(id)
shard := hash % uint32(p.shardNum)
for k := 0; k < len(p.collectionSchema.Fields); k++ {
schema := p.collectionSchema.Fields[k]
srcData := fieldsData[schema.GetName()]
targetData := segmentsData[shard][schema.GetName()]
appendFunc := appendFunctions[schema.GetName()]
err := appendFunc(srcData, i, targetData)
if err != nil {
return err
}
}
}
// call flush function
for i := 0; i < int(p.shardNum); i++ {
segmentData := segmentsData[i]
p.printFieldsDataInfo(segmentData, "import wrapper: prepare to flush segment", files)
err := p.callFlushFunc(segmentData)
if err != nil {
return err
}
}
return nil
}
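Below is a minimal caller-side sketch (not part of the committed files) of how the wrapper is driven for a row-based import; the shard number, segment size, and no-op flush callback are illustrative placeholders, and the schema and ID allocator are assumed to be supplied by the surrounding process:

// sketch: drive ImportWrapper for a row-based JSON import
func runRowBasedImport(ctx context.Context, schema *schemapb.CollectionSchema,
	idAlloc *allocator.IDAllocator, files []string) error {
	// the callback receives one segment's worth of field data at a time
	flushFunc := func(fields map[string]storage.FieldData) error {
		// persist the segment here (e.g. hand it over to the data node)
		return nil
	}
	wrapper := NewImportWrapper(ctx, schema, 2 /*shardNum*/, 512 /*segmentSize in MB*/, idAlloc, flushFunc)
	if wrapper == nil {
		return errors.New("failed to create import wrapper")
	}
	// rowBased=true, onlyValidate=false
	return wrapper.Import(files, true, false)
}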

View File

@@ -0,0 +1,205 @@
package importutil
import (
"context"
"os"
"testing"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
)
const (
TempFilesPath = "/tmp/milvus_test/import/"
)
func Test_NewImportWrapper(t *testing.T) {
ctx := context.Background()
wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, nil)
assert.Nil(t, wrapper)
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 0),
}
schema.Fields = append(schema.Fields, sampleSchema().Fields...)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 106,
Name: common.RowIDFieldName,
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
})
wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, nil)
assert.NotNil(t, wrapper)
err := wrapper.Cancel()
assert.Nil(t, err)
}
func saveFile(t *testing.T, filePath string, content []byte) *os.File {
fp, err := os.Create(filePath)
assert.Nil(t, err)
_, err = fp.Write(content)
assert.Nil(t, err)
return fp
}
func Test_ImportRowBased(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
idAllocator := newIDAllocator(ctx, t)
content := []byte(`{
"rows":[
{"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
{"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]},
{"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]},
{"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]},
{"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]}
]
}`)
filePath := TempFilesPath + "rows_1.json"
fp1 := saveFile(t, filePath, content)
defer fp1.Close()
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
if count == 0 {
count = data.RowNum()
} else {
assert.Equal(t, count, data.RowNum())
}
}
rowCount += count
return nil
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, true, false)
assert.Nil(t, err)
assert.Equal(t, 5, rowCount)
// parse error
content = []byte(`{
"rows":[
{"field_bool": true, "field_int8": false, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
]
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, true, false)
assert.NotNil(t, err)
// file doesn't exist
files = make([]string, 0)
files = append(files, "/dummy/dummy.json")
err = wrapper.Import(files, true, false)
assert.NotNil(t, err)
}
func Test_ImportColumnBased(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
idAllocator := newIDAllocator(ctx, t)
content := []byte(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"],
"field_binary_vector": [
[254, 1],
[253, 2],
[252, 3],
[251, 4],
[250, 5]
],
"field_float_vector": [
[1.1, 1.2, 1.3, 1.4],
[2.1, 2.2, 2.3, 2.4],
[3.1, 3.2, 3.3, 3.4],
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4]
]
}`)
filePath := TempFilesPath + "columns_1.json"
fp1 := saveFile(t, filePath, content)
defer fp1.Close()
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
if count == 0 {
count = data.RowNum()
} else {
assert.Equal(t, count, data.RowNum())
}
}
rowCount += count
return nil
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
assert.Nil(t, err)
assert.Equal(t, 5, rowCount)
// parse error
content = []byte(`{
"field_bool": [true, false, true, true, true]
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
assert.NotNil(t, err)
// file doesn't exist
files = make([]string, 0)
files = append(files, "/dummy/dummy.json")
err = wrapper.Import(files, false, false)
assert.NotNil(t, err)
}

View File

@@ -0,0 +1,714 @@
package importutil
import (
"errors"
"fmt"
"strconv"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// interface to process rows data
type JSONRowHandler interface {
Handle(rows []map[string]interface{}) error
}
// interface to process column data
type JSONColumnHandler interface {
Handle(columns map[string][]interface{}) error
}
// method to get the dimension of a vector field
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
for _, kvPair := range schema.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
if key == "dim" {
dim, err := strconv.Atoi(value)
if err != nil {
return 0, errors.New("vector dimension is invalid")
}
return dim, nil
}
}
return 0, errors.New("vector dimension is not defined")
}
// field value validator
type Validator struct {
validateFunc func(obj interface{}) error // validate data type function
convertFunc func(obj interface{}, field storage.FieldData) error // convert data function
primaryKey bool // true for primary key
autoID bool // only for primary key field
dimension int // only for vector field
}
// method to construct validator functions
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[string]*Validator) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
// the json decoder parses all numeric values into float64
numericValidator := func(obj interface{}) error {
switch obj.(type) {
case float64:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal numeric value " + s
return errors.New(msg)
}
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetName()] = &Validator{}
validators[schema.GetName()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetName()].autoID = schema.GetAutoID()
switch schema.DataType {
case schemapb.DataType_Bool:
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case bool:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal value " + s + " for bool type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(bool)
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
field.(*storage.BoolFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Float:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := float32(obj.(float64))
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
field.(*storage.FloatFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Double:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(float64)
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
field.(*storage.DoubleFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int8(obj.(float64))
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
field.(*storage.Int8FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int16(obj.(float64))
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
field.(*storage.Int16FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int32(obj.(float64))
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
field.(*storage.Int32FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int64(obj.(float64))
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
field.(*storage.Int64FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetName()].dimension = dim
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt)*8 != dim {
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim)
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
t := int(vt[i].(float64))
if t >= 255 || t < 0 {
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := byte(arr[i].(float64))
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
}
field.(*storage.BinaryVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetName()].dimension = dim
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt) != dim {
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim)
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for float vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for float vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := float32(arr[i].(float64))
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
}
field.(*storage.FloatVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_String:
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case string:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not a string for string type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(string)
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
field.(*storage.StringFieldData).NumRows[0]++
return nil
}
default:
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
}
}
return nil
}
// row-based json format validator class
type JSONRowValidator struct {
downstream JSONRowHandler // downstream processor, typically a JSONRowConsumer
validators map[string]*Validator // validators for each field
rowCounter int64 // how many rows have been validated
}
func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream JSONRowHandler) *JSONRowValidator {
v := &JSONRowValidator{
validators: make(map[string]*Validator),
downstream: downstream,
rowCounter: 0,
}
initValidators(collectionSchema, v.validators)
return v
}
func (v *JSONRowValidator) ValidateCount() int64 {
return v.rowCounter
}
func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row validator is not initialized")
}
// parse completed
if rows == nil {
log.Debug("JSON row validation finished")
if v.downstream != nil {
return v.downstream.Handle(rows)
}
return nil
}
for i := 0; i < len(rows); i++ {
row := rows[i]
for name, validator := range v.validators {
if validator.primaryKey && validator.autoID {
// auto-generated primary key, ignore
continue
}
value, ok := row[name]
if !ok {
return errors.New("JSON row validator: field " + name + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
}
if err := validator.validateFunc(value); err != nil {
return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10))
}
}
}
v.rowCounter += int64(len(rows))
if v.downstream != nil {
return v.downstream.Handle(rows)
}
return nil
}
// column-based json format validator class
type JSONColumnValidator struct {
downstream JSONColumnHandler // downstream processor, typically a JSONColumnConsumer
validators map[string]*Validator // validators for each field
rowCounter map[string]int64 // row count of each field
}
func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) *JSONColumnValidator {
v := &JSONColumnValidator{
validators: make(map[string]*Validator),
downstream: downstream,
rowCounter: make(map[string]int64),
}
initValidators(schema, v.validators)
for k := range v.validators {
v.rowCounter[k] = 0
}
return v
}
func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
return v.rowCounter
}
func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column validator is not initialized")
}
// parse completed
if columns == nil {
// all columns are parsed?
maxCount := int64(0)
for _, counter := range v.rowCounter {
if counter > maxCount {
maxCount = counter
}
}
for k := range v.validators {
counter, ok := v.rowCounter[k]
if !ok || counter != maxCount {
return errors.New("JSON column validator: the field " + k + " row count is not equal to other fields")
}
}
log.Debug("JSON column validation finished")
if v.downstream != nil {
return v.downstream.Handle(nil)
}
return nil
}
for name, values := range columns {
validator, ok := v.validators[name]
if !ok {
// not a valid field name, skip this column
continue
}
for i := 0; i < len(values); i++ {
if err := validator.validateFunc(values[i]); err != nil {
return errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10))
}
}
v.rowCounter[name] += int64(len(values))
}
if v.downstream != nil {
return v.downstream.Handle(columns)
}
return nil
}
// row-based json format consumer class
type JSONRowConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
validators map[string]*Validator // validators for each field
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[string]storage.FieldData // in-memory segments data
segmentSize int32 // maximum size of a segment in MB
primaryKey string // name of primary key
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]storage.FieldData {
segmentData := make(map[string]storage.FieldData)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
switch schema.DataType {
case schemapb.DataType_Bool:
segmentData[schema.GetName()] = &storage.BoolFieldData{
Data: make([]bool, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Float:
segmentData[schema.GetName()] = &storage.FloatFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Double:
segmentData[schema.GetName()] = &storage.DoubleFieldData{
Data: make([]float64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int8:
segmentData[schema.GetName()] = &storage.Int8FieldData{
Data: make([]int8, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int16:
segmentData[schema.GetName()] = &storage.Int16FieldData{
Data: make([]int16, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int32:
segmentData[schema.GetName()] = &storage.Int32FieldData{
Data: make([]int32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int64:
segmentData[schema.GetName()] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_BinaryVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetName()] = &storage.BinaryVectorFieldData{
Data: make([]byte, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_FloatVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetName()] = &storage.FloatVectorFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_String:
segmentData[schema.GetName()] = &storage.StringFieldData{
Data: make([]string, 0),
NumRows: []int64{0},
}
default:
log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", int(schema.DataType)))
return nil
}
}
return segmentData
}
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int32,
flushFunc func(fields map[string]storage.FieldData) error) *JSONRowConsumer {
if collectionSchema == nil {
log.Error("JSON row consumer: collection schema is nil")
return nil
}
v := &JSONRowConsumer{
collectionSchema: collectionSchema,
rowIDAllocator: idAlloc,
validators: make(map[string]*Validator),
shardNum: shardNum,
segmentSize: segmentSize,
rowCounter: 0,
callFlushFunc: flushFunc,
}
initValidators(collectionSchema, v.validators)
v.segmentsData = make([]map[string]storage.FieldData, 0, shardNum)
for i := 0; i < int(shardNum); i++ {
segmentData := initSegmentData(collectionSchema)
if segmentData == nil {
return nil
}
v.segmentsData = append(v.segmentsData, segmentData)
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetName()
break
}
}
// primary key not found
if v.primaryKey == "" {
log.Error("JSON row consumer: collection schema has no primary key")
return nil
}
// primary key is autoid, id generator is required
if v.validators[v.primaryKey].autoID && idAlloc == nil {
log.Error("JSON row consumer: ID allocator is nil")
return nil
}
return v
}
func (v *JSONRowConsumer) flush(force bool) error {
// force flush all data
if force {
for i := 0; i < len(v.segmentsData); i++ {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
if rowNum > 0 {
log.Debug("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
if err := v.callFlushFunc(segmentData); err != nil {
return err
}
}
}
return nil
}
// flush a segment when its in-memory size exceeds the limit
for i := 0; i < len(v.segmentsData); i++ {
segmentData := v.segmentsData[i]
memSize := 0
for _, field := range segmentData {
memSize += field.GetMemorySize()
}
if memSize >= int(v.segmentSize)*1024*1024 {
log.Debug("JSON row consumer: flush full segment", zap.Int("bytes", memSize))
if err := v.callFlushFunc(segmentData); err != nil {
return err
}
v.segmentsData[i] = initSegmentData(v.collectionSchema)
}
}
return nil
}
func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row consumer is not initialized")
}
// flush when necessary; a nil rows slice means parsing is finished
if rows == nil {
err := v.flush(true)
log.Debug("JSON row consumer finished")
return err
}
err := v.flush(false)
if err != nil {
return err
}
// prepare autoid
primaryValidator := v.validators[v.primaryKey]
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
if primaryValidator.autoID {
var err error
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
if err != nil {
return errors.New("JSON row consumer: " + err.Error())
}
if rowIDEnd-rowIDBegin != int64(len(rows)) {
return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows")
}
}
// consume rows
for i := 0; i < len(rows); i++ {
row := rows[i]
// first, get or generate the row id
var id int64
if primaryValidator.autoID {
id = rowIDBegin + int64(i)
} else {
value := row[v.primaryKey]
id = int64(value.(float64))
}
// hash to a shard number
hash, _ := typeutil.Hash32Int64(id)
shard := hash % uint32(v.shardNum)
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData)
pkArray.Data = append(pkArray.Data, id)
// convert value and consume
for name, validator := range v.validators {
if validator.primaryKey {
continue
}
value := row[name]
if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil {
return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter, 10))
}
}
}
v.rowCounter += int64(len(rows))
return nil
}
// column-based json format consumer class
type JSONColumnConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
validators map[string]*Validator // validators for each field
fieldsData map[string]storage.FieldData // in-memory fields data
primaryKey string // name of primary key
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
}
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
flushFunc func(fields map[string]storage.FieldData) error) *JSONColumnConsumer {
if collectionSchema == nil {
return nil
}
v := &JSONColumnConsumer{
collectionSchema: collectionSchema,
validators: make(map[string]*Validator),
callFlushFunc: flushFunc,
}
initValidators(collectionSchema, v.validators)
v.fieldsData = initSegmentData(collectionSchema)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetName()
break
}
}
return v
}
func (v *JSONColumnConsumer) flush() error {
// check row count, should be equal
rowCount := 0
for name, field := range v.fieldsData {
if name == v.primaryKey && v.validators[v.primaryKey].autoID {
continue
}
cnt := field.RowNum()
if rowCount == 0 {
rowCount = cnt
} else if rowCount != cnt {
return errors.New("JSON column consumer: " + name + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
}
}
if rowCount == 0 {
return errors.New("JSON column consumer: row count is 0")
}
log.Debug("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
// output the fields data; the caller is responsible for splitting them into segments
return v.callFlushFunc(v.fieldsData)
}
func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column consumer is not initialized")
}
// flush at the end
if columns == nil {
err := v.flush()
log.Debug("JSON column consumer finished")
return err
}
for name, values := range columns {
validator, ok := v.validators[name]
if !ok {
// not a valid field name, skip this column
continue
}
if validator.primaryKey && validator.autoID {
// an auto-generated primary key doesn't need to be provided, skip it
continue
}
// convert and consume data
for i := 0; i < len(values); i++ {
if err := validator.convertFunc(values[i], v.fieldsData[name]); err != nil {
return errors.New("JSON column consumer: " + err.Error() + " of field " + name)
}
}
}
return nil
}
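The validators and consumers above compose into a simple pipeline: the parser buffers decoded rows, the validator type-checks each field, and the consumer converts values and shards rows by primary key before flushing. A minimal sketch of wiring the three stages by hand, assuming the usual context/io imports and a caller-supplied flush callback (illustrative only; in production the ImportWrapper does this wiring):

// sketch: parser -> validator -> consumer pipeline for row-based JSON
func consumeRowStream(ctx context.Context, schema *schemapb.CollectionSchema,
	idAlloc *allocator.IDAllocator, r io.Reader,
	flushFunc func(fields map[string]storage.FieldData) error) error {
	consumer := NewJSONRowConsumer(schema, idAlloc, 2 /*shardNum*/, 512 /*segmentSize in MB*/, flushFunc)
	if consumer == nil {
		return errors.New("failed to create row consumer")
	}
	// the consumer acts as the validator's downstream handler
	validator := NewJSONRowValidator(schema, consumer)
	parser := NewJSONParser(ctx, schema)
	return parser.ParseRows(r, validator)
}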

View File

@@ -0,0 +1,397 @@
package importutil
import (
"context"
"strings"
"testing"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
)
type mockIDAllocator struct {
}
func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) {
return &rootcoordpb.AllocIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ID: int64(1),
Count: req.Count,
}, nil
}
func newIDAllocator(ctx context.Context, t *testing.T) *allocator.IDAllocator {
mockIDAllocator := &mockIDAllocator{}
idAllocator, err := allocator.NewIDAllocator(ctx, mockIDAllocator, int64(1))
assert.Nil(t, err)
err = idAllocator.Start()
assert.Nil(t, err)
return idAllocator
}
func Test_GetFieldDimension(t *testing.T) {
schema := &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
}
dim, err := getFieldDimension(schema)
assert.Nil(t, err)
assert.Equal(t, 4, dim)
schema.TypeParams = []*commonpb.KeyValuePair{
{Key: "dim", Value: "abc"},
}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
schema.TypeParams = []*commonpb.KeyValuePair{}
dim, err = getFieldDimension(schema)
assert.NotNil(t, err)
assert.Equal(t, 0, dim)
}
func Test_InitValidators(t *testing.T) {
validators := make(map[string]*Validator)
err := initValidators(nil, validators)
assert.NotNil(t, err)
// success case
err = initValidators(sampleSchema(), validators)
assert.Nil(t, err)
assert.Equal(t, len(sampleSchema().Fields), len(validators))
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
v, ok := validators[funcName]
assert.True(t, ok)
err = v.validateFunc(validVal)
assert.Nil(t, err)
err = v.validateFunc(invalidVal)
assert.NotNil(t, err)
}
// validate functions
var validVal interface{} = true
var invalidVal interface{} = "aa"
checkFunc("field_bool", validVal, invalidVal)
validVal = float64(100)
invalidVal = "aa"
checkFunc("field_int8", validVal, invalidVal)
checkFunc("field_int16", validVal, invalidVal)
checkFunc("field_int32", validVal, invalidVal)
checkFunc("field_int64", validVal, invalidVal)
checkFunc("field_float", validVal, invalidVal)
checkFunc("field_double", validVal, invalidVal)
validVal = "aa"
invalidVal = 100
checkFunc("field_string", validVal, invalidVal)
validVal = []interface{}{float64(100), float64(101)}
invalidVal = "aa"
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(100), float64(101), float64(102)}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{true, true}
checkFunc("field_binary_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(255), float64(-1)}
checkFunc("field_binary_vector", validVal, invalidVal)
validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)}
invalidVal = true
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}
checkFunc("field_float_vector", validVal, invalidVal)
invalidVal = []interface{}{"a", "b", "c", "d"}
checkFunc("field_float_vector", validVal, invalidVal)
// error cases
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 0),
}
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
validators = make(map[string]*Validator)
err = initValidators(schema, validators)
assert.NotNil(t, err)
schema.Fields = make([]*schemapb.FieldSchema, 0)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "aa"},
},
})
err = initValidators(schema, validators)
assert.NotNil(t, err)
}
func Test_JSONRowValidator(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
// 0 row case
reader := strings.NewReader(`{
"rows":[]
}`)
validator := NewJSONRowValidator(schema, nil)
err := parser.ParseRows(reader, validator)
assert.Nil(t, err)
assert.Equal(t, int64(0), validator.ValidateCount())
// // missed some fields
// reader = strings.NewReader(`{
// "rows":[
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
// {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
// ]
// }`)
// err = parser.ParseRows(reader, validator)
// assert.NotNil(t, err)
// invalid dimension
reader = strings.NewReader(`{
"rows":[
{"field_bool": true, "field_int8": true, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0, 1, 66, 128, 0, 1, 66], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
]
}`)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// invalid value type
reader = strings.NewReader(`{
"rows":[
{"field_bool": true, "field_int8": true, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}
]
}`)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
// init failed
validator.validators = nil
err = validator.Handle(nil)
assert.NotNil(t, err)
}
func Test_JSONColumnValidator(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
// 0 row case
reader := strings.NewReader(`{
"field_bool": [],
"field_int8": [],
"field_int16": [],
"field_int32": [],
"field_int64": [],
"field_float": [],
"field_double": [],
"field_string": [],
"field_binary_vector": [],
"field_float_vector": []
}`)
validator := NewJSONColumnValidator(schema, nil)
err := parser.ParseColumns(reader, validator)
assert.Nil(t, err)
for _, count := range validator.rowCounter {
assert.Equal(t, int64(0), count)
}
// different row count
reader = strings.NewReader(`{
"field_bool": [true],
"field_int8": [],
"field_int16": [],
"field_int32": [1, 2, 3],
"field_int64": [],
"field_float": [],
"field_double": [],
"field_string": [],
"field_binary_vector": [],
"field_float_vector": []
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
// invalid value type
reader = strings.NewReader(`{
"dummy": [],
"field_bool": [true],
"field_int8": [1],
"field_int16": [2],
"field_int32": [3],
"field_int64": [4],
"field_float": [1],
"field_double": [1],
"field_string": [9],
"field_binary_vector": [[254, 1]],
"field_float_vector": [[1.1, 1.2, 1.3, 1.4]]
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
// init failed
validator.validators = nil
err = validator.Handle(nil)
assert.NotNil(t, err)
}
func Test_JSONRowConsumer(t *testing.T) {
ctx := context.Background()
idAllocator := newIDAllocator(ctx, t)
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
reader := strings.NewReader(`{
"rows":[
{"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
{"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]},
{"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]},
{"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]},
{"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]}
]
}`)
var callTime int32
var totalCount int
consumeFunc := func(fields map[string]storage.FieldData) error {
callTime++
rowCount := 0
for _, data := range fields {
if rowCount == 0 {
rowCount = data.RowNum()
} else {
assert.Equal(t, rowCount, data.RowNum())
}
}
totalCount += rowCount
return nil
}
var shardNum int32 = 2
consumer := NewJSONRowConsumer(schema, idAllocator, shardNum, 1, consumeFunc)
assert.NotNil(t, consumer)
validator := NewJSONRowValidator(schema, consumer)
err := parser.ParseRows(reader, validator)
assert.Nil(t, err)
assert.Equal(t, int64(5), validator.ValidateCount())
assert.Equal(t, shardNum, callTime)
assert.Equal(t, 5, totalCount)
}
func Test_JSONColumnConsumer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
reader := strings.NewReader(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"],
"field_binary_vector": [
[254, 1],
[253, 2],
[252, 3],
[251, 4],
[250, 5]
],
"field_float_vector": [
[1.1, 1.2, 1.3, 1.4],
[2.1, 2.2, 2.3, 2.4],
[3.1, 3.2, 3.3, 3.4],
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4]
]
}`)
callTime := 0
rowCount := 0
consumeFunc := func(fields map[string]storage.FieldData) error {
callTime++
for _, data := range fields {
if rowCount == 0 {
rowCount = data.RowNum()
} else {
assert.Equal(t, rowCount, data.RowNum())
}
}
return nil
}
consumer := NewJSONColumnConsumer(schema, consumeFunc)
assert.NotNil(t, consumer)
validator := NewJSONColumnValidator(schema, consumer)
err := parser.ParseColumns(reader, validator)
assert.Nil(t, err)
for _, count := range validator.ValidateCount() {
assert.Equal(t, int64(5), count)
}
assert.Equal(t, 1, callTime)
assert.Equal(t, 5, rowCount)
}

View File

@@ -0,0 +1,238 @@
package importutil
import (
"context"
"encoding/json"
"errors"
"io"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
)
const (
// root field of row-based json format
RowRootNode = "rows"
// initial size of a buffer
BufferSize = 1024
)
type JSONParser struct {
ctx context.Context // for canceling parse process
bufSize int64 // max rows in a buffer
fields map[string]int64 // fields need to be parsed
}
// NewJSONParser is a helper function to create a JSONParser
func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
fields := make(map[string]int64)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
fields[schema.GetName()] = 0
}
parser := &JSONParser{
ctx: ctx,
bufSize: 4096,
fields: fields,
}
return parser
}
func (p *JSONParser) logError(msg string) error {
log.Error(msg)
return errors.New(msg)
}
func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
if handler == nil {
return p.logError("JSON parse handler is nil")
}
dec := json.NewDecoder(r)
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim('{') {
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
}
// read the first level
for dec.More() {
// read the key
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
key := t.(string)
// the root key should be RowRootNode
if key != RowRootNode {
return p.logError("JSON parse: invalid row-based JSON format, the key " + RowRootNode + " is not found")
}
// started by '['
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim('[') {
return p.logError("JSON parse: invalid row-based JSON format, rows list should begin with '['")
}
// read buffer
buf := make([]map[string]interface{}, 0, BufferSize)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
return p.logError("JSON parse: " + err.Error())
}
row, ok := value.(map[string]interface{})
if !ok {
return p.logError("JSON parse: invalid JSON format, each row should be a key-value map")
}
buf = append(buf, row)
if len(buf) >= int(p.bufSize) {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
// clear the buffer
buf = make([]map[string]interface{}, 0, BufferSize)
}
}
// hand over any rows remaining in the buffer
if len(buf) > 0 {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
}
// end by ']'
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim(']') {
return p.logError("JSON parse: invalid column-based JSON format, rows list should end with a ']'")
}
// canceled?
select {
case <-p.ctx.Done():
return p.logError("import task was canceled")
default:
break
}
// this break enforces that the first-level node must be the RowRootNode
// once the RowRootNode is parsed, the parsing is finished
break
}
// send nil to notify the handler that parsing is done
return handler.Handle(nil)
}
func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error {
if handler == nil {
return p.logError("JSON parse handler is nil")
}
dec := json.NewDecoder(r)
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim('{') {
return p.logError("JSON parse: invalid JSON format, the content should be started with'{'")
}
// read the first level
for dec.More() {
// read the key
t, err := dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
key := t.(string)
// not a valid column name, skip
_, isValidField := p.fields[key]
// started by '['
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim('[') {
return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['")
}
// read buffer
buf := make(map[string][]interface{})
buf[key] = make([]interface{}, 0, BufferSize)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
return p.logError("JSON parse: " + err.Error())
}
if !isValidField {
continue
}
buf[key] = append(buf[key], value)
if len(buf[key]) >= int(p.bufSize) {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
// clear the buffer
buf[key] = make([]interface{}, 0, BufferSize)
}
}
// hand over any values remaining in the buffer
if len(buf[key]) > 0 {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
}
// end by ']'
t, err = dec.Token()
if err != nil {
return p.logError("JSON parse: " + err.Error())
}
if t != json.Delim(']') {
return p.logError("JSON parse: invalid column-based JSON format, each field should end with a ']'")
}
// canceled?
select {
case <-p.ctx.Done():
return p.logError("import task was canceled")
default:
break
}
}
// send nil to notify the handler that parsing is done
return handler.Handle(nil)
}
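Both parse methods stream buffered chunks to their handler while decoding and finish with a nil call. A minimal sketch of a custom JSONColumnHandler that observes this contract by counting values per field (illustrative only, not part of the commit):

// sketch: a JSONColumnHandler that only counts values per field;
// a nil map signals end-of-stream, mirroring the validator/consumer contract
type countingColumnHandler struct {
	counts map[string]int64
}

func (h *countingColumnHandler) Handle(columns map[string][]interface{}) error {
	if columns == nil {
		// parsing finished, nothing left to count
		return nil
	}
	for name, values := range columns {
		h.counts[name] += int64(len(values))
	}
	return nil
}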

View File

@@ -0,0 +1,256 @@
package importutil
import (
"context"
"strings"
"testing"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/stretchr/testify/assert"
)
func sampleSchema() *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{
FieldID: 102,
Name: "field_bool",
IsPrimaryKey: false,
Description: "bool",
DataType: schemapb.DataType_Bool,
},
{
FieldID: 103,
Name: "field_int8",
IsPrimaryKey: false,
Description: "int8",
DataType: schemapb.DataType_Int8,
},
{
FieldID: 104,
Name: "field_int16",
IsPrimaryKey: false,
Description: "int16",
DataType: schemapb.DataType_Int16,
},
{
FieldID: 105,
Name: "field_int32",
IsPrimaryKey: false,
Description: "int32",
DataType: schemapb.DataType_Int32,
},
{
FieldID: 106,
Name: "field_int64",
IsPrimaryKey: true,
AutoID: false,
Description: "int64",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 107,
Name: "field_float",
IsPrimaryKey: false,
Description: "float",
DataType: schemapb.DataType_Float,
},
{
FieldID: 108,
Name: "field_double",
IsPrimaryKey: false,
Description: "double",
DataType: schemapb.DataType_Double,
},
{
FieldID: 109,
Name: "field_string",
IsPrimaryKey: false,
Description: "string",
DataType: schemapb.DataType_String,
},
{
FieldID: 110,
Name: "field_binary_vector",
IsPrimaryKey: false,
Description: "binary_vector",
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "16"},
},
},
{
FieldID: 111,
Name: "field_float_vector",
IsPrimaryKey: false,
Description: "float_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
return schema
}
func Test_ParserRows(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
parser.bufSize = 1
reader := strings.NewReader(`{
"rows":[
{"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]},
{"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]},
{"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]},
{"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]},
{"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]}
]
}`)
err := parser.ParseRows(reader, nil)
assert.NotNil(t, err)
validator := NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.Nil(t, err)
assert.Equal(t, int64(5), validator.ValidateCount())
reader = strings.NewReader(`{
"dummy":[]
}`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"rows":
}`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"rows": [}
}`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"rows": {}
}`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"rows": [[]]
}`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`[]`)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(``)
validator = NewJSONRowValidator(schema, nil)
err = parser.ParseRows(reader, validator)
assert.NotNil(t, err)
}
func Test_ParserColumns(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
schema := sampleSchema()
parser := NewJSONParser(ctx, schema)
assert.NotNil(t, parser)
parser.bufSize = 1
reader := strings.NewReader(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"],
"field_binary_vector": [
[254, 1],
[253, 2],
[252, 3],
[251, 4],
[250, 5]
],
"field_float_vector": [
[1.1, 1.2, 1.3, 1.4],
[2.1, 2.2, 2.3, 2.4],
[3.1, 3.2, 3.3, 3.4],
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4]
]
}`)
err := parser.ParseColumns(reader, nil)
assert.NotNil(t, err)
validator := NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
counter := validator.ValidateCount()
for _, v := range counter {
assert.Equal(t, int64(5), v)
}
reader = strings.NewReader(`{
"dummy":[1, 2, 3]
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.Nil(t, err)
reader = strings.NewReader(`{
"field_bool":
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"field_bool":{}
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`{
"field_bool":[}
}`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(`[]`)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
reader = strings.NewReader(``)
validator = NewJSONColumnValidator(schema, nil)
err = parser.ParseColumns(reader, validator)
assert.NotNil(t, err)
}