milvus/internal/util/importutil/json_handler.go

package importutil
import (
"errors"
"fmt"
"strconv"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// interface to process row-based data
type JSONRowHandler interface {
Handle(rows []map[storage.FieldID]interface{}) error
}
// interface to process column data
type JSONColumnHandler interface {
Handle(columns map[storage.FieldID][]interface{}) error
}
// function to get the dimension of a vector field
func getFieldDimension(schema *schemapb.FieldSchema) (int, error) {
for _, kvPair := range schema.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
if key == "dim" {
dim, err := strconv.Atoi(value)
if err != nil {
return 0, errors.New("vector dimension is invalid")
}
return dim, nil
}
}
return 0, errors.New("vector dimension is not defined")
}
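// Illustrative sketch (not part of the original code): a vector field declares
// its dimension through the "dim" type param, which getFieldDimension reads:
//
//	field := &schemapb.FieldSchema{
//		Name:       "vector",
//		DataType:   schemapb.DataType_FloatVector,
//		TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
//	}
//	dim, err := getFieldDimension(field) // dim == 128, err == nil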
// field value validator
type Validator struct {
validateFunc func(obj interface{}) error // validate data type function
convertFunc func(obj interface{}, field storage.FieldData) error // convert data function
primaryKey bool // true for primary key
autoID bool // only for primary key field
isString bool // for string field
dimension int // only for vector field
fieldName string // field name
}
// function to construct validator functions for each field
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
// the json decoder parses all numeric values into float64
numericValidator := func(obj interface{}) error {
switch obj.(type) {
case float64:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal numeric value " + s
return errors.New(msg)
}
}
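	// Note: with encoding/json, a JSON number decoded into an interface{}
	// always becomes a float64, which is why only float64 is accepted here:
	//
	//	var v interface{}
	//	_ = json.Unmarshal([]byte(`42`), &v) // v holds float64(42), not int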
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetFieldID()] = &Validator{}
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
validators[schema.GetFieldID()].fieldName = schema.GetName()
validators[schema.GetFieldID()].isString = false
switch schema.DataType {
case schemapb.DataType_Bool:
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case bool:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := "illegal value " + s + " for bool type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(bool)
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
field.(*storage.BoolFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Float:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := float32(obj.(float64))
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
field.(*storage.FloatFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Double:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(float64)
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
field.(*storage.DoubleFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int8(obj.(float64))
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
field.(*storage.Int8FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int16(obj.(float64))
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
field.(*storage.Int16FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int32(obj.(float64))
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
field.(*storage.Int32FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int64(obj.(float64))
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
field.(*storage.Int64FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_BinaryVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt)*8 != dim {
msg := "bit size " + strconv.Itoa(len(vt)*8) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
t := int(vt[i].(float64))
if t > 255 || t < 0 {
msg := "illegal value " + strconv.Itoa(t) + " for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for binary vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := byte(arr[i].(float64))
field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value)
}
field.(*storage.BinaryVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_FloatVector:
dim, err := getFieldDimension(schema)
if err != nil {
return err
}
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt) != dim {
msg := "array size " + strconv.Itoa(len(vt)) + " doesn't equal to vector dimension " + strconv.Itoa(dim) + " of field " + schema.GetName()
return errors.New(msg)
}
for i := 0; i < len(vt); i++ {
if e := numericValidator(vt[i]); e != nil {
msg := e.Error() + " for float vector field " + schema.GetName()
return errors.New(msg)
}
}
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not an array for float vector field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := float32(arr[i].(float64))
field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value)
}
field.(*storage.FloatVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
validators[schema.GetFieldID()].isString = true
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case string:
return nil
default:
s := fmt.Sprintf("%v", obj)
msg := s + " is not a string for string type field " + schema.GetName()
return errors.New(msg)
}
}
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(string)
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
field.(*storage.StringFieldData).NumRows[0]++
return nil
}
default:
return errors.New("unsupport data type: " + strconv.Itoa(int(collectionSchema.Fields[i].DataType)))
}
}
return nil
}
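// Usage sketch (hypothetical names, for illustration only): build the validator
// table for a schema, then validate and convert one decoded JSON value:
//
//	validators := make(map[storage.FieldID]*Validator)
//	if err := initValidators(collectionSchema, validators); err != nil {
//		return err // unsupported type or invalid vector dimension
//	}
//	v := validators[fieldID]
//	if err := v.validateFunc(jsonValue); err == nil {
//		err = v.convertFunc(jsonValue, fieldData) // appends into storage.FieldData
//	}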
// row-based json format validator class
type JSONRowValidator struct {
	downstream JSONRowHandler // downstream processor, typically a JSONRowConsumer
validators map[storage.FieldID]*Validator // validators for each field
rowCounter int64 // how many rows have been validated
}
func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream JSONRowHandler) *JSONRowValidator {
v := &JSONRowValidator{
validators: make(map[storage.FieldID]*Validator),
downstream: downstream,
rowCounter: 0,
}
	if err := initValidators(collectionSchema, v.validators); err != nil {
		log.Error("JSON row validator: failed to initialize validators", zap.Error(err))
		return nil
	}
return v
}
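// Usage sketch (assumed wiring, not prescribed here): a JSON parser feeds row
// batches to the validator, which forwards them downstream; a nil batch signals
// that parsing is complete:
//
//	validator := NewJSONRowValidator(schema, consumer) // downstream may be nil
//	_ = validator.Handle(rows) // validate one batch of rows
//	_ = validator.Handle(nil)  // signal completion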
func (v *JSONRowValidator) ValidateCount() int64 {
return v.rowCounter
}
func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row validator is not initialized")
}
// parse completed
if rows == nil {
log.Info("JSON row validation finished")
if v.downstream != nil {
return v.downstream.Handle(rows)
}
return nil
}
for i := 0; i < len(rows); i++ {
row := rows[i]
for id, validator := range v.validators {
if validator.primaryKey && validator.autoID {
// auto-generated primary key, ignore
continue
}
value, ok := row[id]
if !ok {
return errors.New("JSON row validator: field " + validator.fieldName + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
}
if err := validator.validateFunc(value); err != nil {
return errors.New("JSON row validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
}
}
}
v.rowCounter += int64(len(rows))
if v.downstream != nil {
return v.downstream.Handle(rows)
}
return nil
}
// column-based json format validator class
type JSONColumnValidator struct {
	downstream JSONColumnHandler // downstream processor, typically a JSONColumnConsumer
validators map[storage.FieldID]*Validator // validators for each field
rowCounter map[string]int64 // row count of each field
}
func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) *JSONColumnValidator {
v := &JSONColumnValidator{
validators: make(map[storage.FieldID]*Validator),
downstream: downstream,
rowCounter: make(map[string]int64),
}
	if err := initValidators(schema, v.validators); err != nil {
		log.Error("JSON column validator: failed to initialize validators", zap.Error(err))
		return nil
	}
return v
}
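// Usage sketch (assumed wiring): same pattern as the row validator, except each
// batch carries column slices and a per-field row count is tracked; the final
// nil batch triggers the row-count consistency check:
//
//	validator := NewJSONColumnValidator(schema, consumer)
//	_ = validator.Handle(columns) // validate one batch of columns
//	_ = validator.Handle(nil)     // verify all fields have equal row counts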
func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
return v.rowCounter
}
func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column validator is not initialized")
}
// parse completed
if columns == nil {
// compare the row count of columns, should be equal
rowCount := int64(-1)
for k, counter := range v.rowCounter {
if rowCount == -1 {
rowCount = counter
} else if rowCount != counter {
return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields row count" + strconv.Itoa(int(rowCount)))
}
}
// let the downstream know parse is completed
log.Info("JSON column validation finished")
if v.downstream != nil {
return v.downstream.Handle(nil)
}
return nil
}
	for id, values := range columns {
		validator, ok := v.validators[id]
		if !ok {
			// not a valid field id, skip without parsing
			continue
		}
		name := validator.fieldName
for i := 0; i < len(values); i++ {
if err := validator.validateFunc(values[i]); err != nil {
return errors.New("JSON column validator: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter[name]+int64(i), 10))
}
}
v.rowCounter[name] += int64(len(values))
}
if v.downstream != nil {
return v.downstream.Handle(columns)
}
return nil
}
type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardNum int) error
// row-based json format consumer class
type JSONRowConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
validators map[storage.FieldID]*Validator // validators for each field
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data
	segmentSize int64 // maximum size of a segment (unit: byte)
	primaryKey storage.FieldID // field id of the primary key
autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25
callFlushFunc ImportFlushFunc // call back function to flush segment
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
segmentData := make(map[storage.FieldID]storage.FieldData)
	// rowID field is a hidden field with fieldID=0; it is always auto-generated by IDAllocator
// if primary key is int64 and autoID=true, primary key field is equal to rowID field
segmentData[common.RowIDField] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
switch schema.DataType {
case schemapb.DataType_Bool:
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
Data: make([]bool, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Float:
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Double:
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
Data: make([]float64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int8:
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
Data: make([]int8, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int16:
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
Data: make([]int16, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int32:
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
Data: make([]int32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int64:
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_BinaryVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
Data: make([]byte, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_FloatVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_String, schemapb.DataType_VarChar:
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
Data: make([]string, 0),
NumRows: []int64{0},
}
default:
log.Error("JSON row consumer error: unsupported data type", zap.Int("DataType", int(schema.DataType)))
return nil
}
}
return segmentData
}
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
flushFunc ImportFlushFunc) *JSONRowConsumer {
if collectionSchema == nil {
log.Error("JSON row consumer: collection schema is nil")
return nil
}
v := &JSONRowConsumer{
collectionSchema: collectionSchema,
rowIDAllocator: idAlloc,
validators: make(map[storage.FieldID]*Validator),
shardNum: shardNum,
segmentSize: segmentSize,
rowCounter: 0,
primaryKey: -1,
autoIDRange: make([]int64, 0),
callFlushFunc: flushFunc,
}
	if err := initValidators(collectionSchema, v.validators); err != nil {
		log.Error("JSON row consumer: failed to initialize validators", zap.Error(err))
		return nil
	}
v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
for i := 0; i < int(shardNum); i++ {
segmentData := initSegmentData(collectionSchema)
if segmentData == nil {
return nil
}
v.segmentsData = append(v.segmentsData, segmentData)
}
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetFieldID()
break
}
}
// primary key not found
if v.primaryKey == -1 {
log.Error("JSON row consumer: collection schema has no primary key")
return nil
}
	// if the primary key is auto-generated, an id allocator is required
if v.validators[v.primaryKey].autoID && idAlloc == nil {
log.Error("JSON row consumer: ID allocator is nil")
return nil
}
return v
}
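// Usage sketch (the flushFunc below is hypothetical): the consumer hashes each
// row to a shard by its primary key and invokes the callback whenever a shard's
// in-memory size reaches segmentSize, and once more on the final nil batch:
//
//	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shard int) error {
//		// persist the fields of this shard as one segment
//		return nil
//	}
//	consumer := NewJSONRowConsumer(schema, idAlloc, 2, 512*1024*1024, flushFunc)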
func (v *JSONRowConsumer) IDRange() []int64 {
return v.autoIDRange
}
func (v *JSONRowConsumer) flush(force bool) error {
// force flush all data
if force {
for i := 0; i < len(v.segmentsData); i++ {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
			if rowNum > 0 {
				log.Info("JSON row consumer: force flush segment", zap.Int("rows", rowNum))
				if err := v.callFlushFunc(segmentData, i); err != nil {
					return err
				}
			}
}
return nil
}
	// flush a segment whose in-memory size reaches the limit
for i := 0; i < len(v.segmentsData); i++ {
segmentData := v.segmentsData[i]
rowNum := segmentData[v.primaryKey].RowNum()
memSize := 0
for _, field := range segmentData {
memSize += field.GetMemorySize()
}
		if memSize >= int(v.segmentSize) && rowNum > 0 {
			log.Info("JSON row consumer: flush full segment", zap.Int("bytes", memSize), zap.Int("rowNum", rowNum))
			if err := v.callFlushFunc(segmentData, i); err != nil {
				return err
			}
			v.segmentsData[i] = initSegmentData(v.collectionSchema)
		}
}
return nil
}
func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row consumer is not initialized")
}
	// flush if necessary
if rows == nil {
err := v.flush(true)
log.Info("JSON row consumer finished")
return err
}
err := v.flush(false)
if err != nil {
return err
}
	// prepare auto-generated primary keys
primaryValidator := v.validators[v.primaryKey]
var rowIDBegin typeutil.UniqueID
var rowIDEnd typeutil.UniqueID
if primaryValidator.autoID {
var err error
rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
if err != nil {
return errors.New("JSON row consumer: " + err.Error())
}
if rowIDEnd-rowIDBegin != int64(len(rows)) {
return errors.New("JSON row consumer: failed to allocate ID for " + strconv.Itoa(len(rows)) + " rows")
}
v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
}
// consume rows
for i := 0; i < len(rows); i++ {
row := rows[i]
// hash to a shard number
var shard uint32
if primaryValidator.isString {
if primaryValidator.autoID {
return errors.New("JSON row consumer: string type primary key cannot be auto-generated")
}
value := row[v.primaryKey]
			pk := value.(string)
hash := typeutil.HashString2Uint32(pk)
shard = hash % uint32(v.shardNum)
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.StringFieldData)
pkArray.Data = append(pkArray.Data, pk)
pkArray.NumRows[0]++
} else {
			// get or generate the primary key
var pk int64
if primaryValidator.autoID {
pk = rowIDBegin + int64(i)
} else {
value := row[v.primaryKey]
pk = int64(value.(float64))
}
hash, _ := typeutil.Hash32Int64(pk)
shard = hash % uint32(v.shardNum)
pkArray := v.segmentsData[shard][v.primaryKey].(*storage.Int64FieldData)
pkArray.Data = append(pkArray.Data, pk)
pkArray.NumRows[0]++
}
// set rowid field
rowIDField := v.segmentsData[shard][common.RowIDField].(*storage.Int64FieldData)
rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i))
rowIDField.NumRows[0]++
// convert value and consume
		for fieldID, validator := range v.validators {
			if validator.primaryKey {
				continue
			}
			value := row[fieldID]
			if err := validator.convertFunc(value, v.segmentsData[shard][fieldID]); err != nil {
return errors.New("JSON row consumer: " + err.Error() + " at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
}
}
}
v.rowCounter += int64(len(rows))
return nil
}
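// Note on sharding (illustrative): the shard is derived solely from the primary
// key, so rows with equal keys always land in the same shard. For a string key
// the computation is:
//
//	shard := typeutil.HashString2Uint32(pk) % uint32(shardNum)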
type ColumnFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
// column-based json format consumer class
type JSONColumnConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
validators map[storage.FieldID]*Validator // validators for each field
fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data
	primaryKey storage.FieldID // field id of the primary key
callFlushFunc ColumnFlushFunc // call back function to flush segment
}
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFunc ColumnFlushFunc) *JSONColumnConsumer {
if collectionSchema == nil {
return nil
}
v := &JSONColumnConsumer{
collectionSchema: collectionSchema,
validators: make(map[storage.FieldID]*Validator),
callFlushFunc: flushFunc,
}
	if err := initValidators(collectionSchema, v.validators); err != nil {
		log.Error("JSON column consumer: failed to initialize validators", zap.Error(err))
		return nil
	}
v.fieldsData = initSegmentData(collectionSchema)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetFieldID()
break
}
}
return v
}
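// Usage sketch (the flushFunc below is hypothetical): unlike the row consumer,
// the column consumer accumulates everything in memory and hands the complete
// field data to the callback once, at the final nil batch:
//
//	flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
//		// split the whole dataset into segments outside
//		return nil
//	}
//	consumer := NewJSONColumnConsumer(schema, flushFunc)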
func (v *JSONColumnConsumer) flush() error {
// check row count, should be equal
rowCount := 0
for id, field := range v.fieldsData {
// skip the autoid field
if id == v.primaryKey && v.validators[v.primaryKey].autoID {
continue
}
cnt := field.RowNum()
		// skip zero-row fields since a data file may import only one column (multiple data files may be imported)
if cnt == 0 {
continue
}
// only check non-zero row fields
if rowCount == 0 {
rowCount = cnt
} else if rowCount != cnt {
return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
}
}
if rowCount == 0 {
return errors.New("JSON column consumer: row count is 0")
}
log.Info("JSON column consumer: rows parsed", zap.Int("rowCount", rowCount))
	// output the fields data and let the caller split it into segments
return v.callFlushFunc(v.fieldsData)
}
func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
if v == nil || v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column consumer is not initialized")
}
// flush at the end
if columns == nil {
err := v.flush()
log.Info("JSON column consumer finished")
return err
}
// consume columns data
	for id, values := range columns {
		validator, ok := v.validators[id]
		if !ok {
			// not a valid field id, skip without parsing
			continue
		}
		if validator.primaryKey && validator.autoID {
			// an auto-generated primary key does not need to be provided
			continue
		}
// convert and consume data
for i := 0; i < len(values); i++ {
if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil {
return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10))
}
}
}
return nil
}