package storage

import (
	"bytes"
	"errors"
	"fmt"
	"reflect"

	"github.com/apache/arrow/go/v8/arrow"
	"github.com/apache/arrow/go/v8/parquet"
	"github.com/apache/arrow/go/v8/parquet/file"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
)

// PayloadReader reads data from payload
type PayloadReader struct {
	reader  *file.Reader
	colType schemapb.DataType
	numRows int64
}

func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
	if len(buf) == 0 {
		return nil, errors.New("create Payload reader failed, buffer is empty")
	}
	parquetReader, err := file.NewParquetReader(bytes.NewReader(buf))
	if err != nil {
		return nil, err
	}
	return &PayloadReader{reader: parquetReader, colType: colType, numRows: parquetReader.NumRows()}, nil
}
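
// Example (sketch): a minimal, hypothetical use of PayloadReader, assuming
// `buf` holds a parquet-encoded payload produced by a matching payload writer
// for an Int64 column:
//
//	reader, err := NewPayloadReader(schemapb.DataType_Int64, buf)
//	if err != nil {
//		// handle error
//	}
//	defer reader.ReleasePayloadReader()
//	values, err := reader.GetInt64FromPayload()
//	// values holds one int64 per row in the payload.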

// GetDataFromPayload returns data and length from the payload; it returns an error if it fails.
// Return:
//	`interface{}`: data of any supported type.
//	`int`: dim, only meaningful for FLOAT/BINARY VECTOR types.
//	`error`: error.
func (r *PayloadReader) GetDataFromPayload() (interface{}, int, error) {
	switch r.colType {
	case schemapb.DataType_Bool:
		val, err := r.GetBoolFromPayload()
		return val, 0, err
	case schemapb.DataType_Int8:
		val, err := r.GetInt8FromPayload()
		return val, 0, err
	case schemapb.DataType_Int16:
		val, err := r.GetInt16FromPayload()
		return val, 0, err
	case schemapb.DataType_Int32:
		val, err := r.GetInt32FromPayload()
		return val, 0, err
	case schemapb.DataType_Int64:
		val, err := r.GetInt64FromPayload()
		return val, 0, err
	case schemapb.DataType_Float:
		val, err := r.GetFloatFromPayload()
		return val, 0, err
	case schemapb.DataType_Double:
		val, err := r.GetDoubleFromPayload()
		return val, 0, err
	case schemapb.DataType_BinaryVector:
		return r.GetBinaryVectorFromPayload()
	case schemapb.DataType_FloatVector:
		return r.GetFloatVectorFromPayload()
	case schemapb.DataType_String, schemapb.DataType_VarChar:
		val, err := r.GetStringFromPayload()
		return val, 0, err
	default:
		return nil, 0, errors.New("unknown type")
	}
}
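
// Example (sketch): hypothetical handling of GetDataFromPayload for a float
// vector column (assuming `reader` is a *PayloadReader created for that column),
// where the returned dim is needed to interpret the flat slice:
//
//	data, dim, err := reader.GetDataFromPayload()
//	if err != nil {
//		// handle error
//	}
//	if vecs, ok := data.([]float32); ok {
//		firstRow := vecs[:dim] // the first vector of `dim` float32 values
//		_ = firstRow
//	}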

// ReleasePayloadReader releases the payload reader.
func (r *PayloadReader) ReleasePayloadReader() {
	r.Close()
}

// GetBoolFromPayload returns bool slice from payload.
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
	if r.colType != schemapb.DataType_Bool {
		return nil, fmt.Errorf("failed to get bool from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]bool, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(bool)
	}
	return ret, nil
}

// GetByteFromPayload returns byte slice from payload
func (r *PayloadReader) GetByteFromPayload() ([]byte, error) {
	if r.colType != schemapb.DataType_Int8 {
		return nil, fmt.Errorf("failed to get byte from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]byte, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = byte(v.(int32))
	}
	return ret, nil
}

// GetInt8FromPayload returns int8 slice from payload
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
	if r.colType != schemapb.DataType_Int8 {
		return nil, fmt.Errorf("failed to get int8 from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]int8, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		// need to convert because parquet does not support int8
		ret[i] = int8(v.(int32))
	}
	return ret, nil
}

// GetInt16FromPayload returns int16 slice from payload
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
	if r.colType != schemapb.DataType_Int16 {
		return nil, fmt.Errorf("failed to get int16 from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]int16, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		// need to convert because parquet does not support int16
		ret[i] = int16(v.(int32))
	}
	return ret, nil
}

// GetInt32FromPayload returns int32 slice from payload
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
	if r.colType != schemapb.DataType_Int32 {
		return nil, fmt.Errorf("failed to get int32 from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]int32, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(int32)
	}
	return ret, nil
}

// GetInt64FromPayload returns int64 slice from payload
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
	if r.colType != schemapb.DataType_Int64 {
		return nil, fmt.Errorf("failed to get int64 from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]int64, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(int64)
	}
	return ret, nil
}

// GetFloatFromPayload returns float32 slice from payload
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
	if r.colType != schemapb.DataType_Float {
		return nil, fmt.Errorf("failed to get float32 from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]float32, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(float32)
	}
	return ret, nil
}

// GetDoubleFromPayload returns float64 slice from payload
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
	if r.colType != schemapb.DataType_Double {
		return nil, fmt.Errorf("failed to get double from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]float64, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(float64)
	}
	return ret, nil
}

// GetStringFromPayload returns string slice from payload
func (r *PayloadReader) GetStringFromPayload() ([]string, error) {
	if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar {
		return nil, fmt.Errorf("failed to get string from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, err
	}
	ret := make([]string, r.numRows)
	var i int64
	for i = 0; i < r.numRows; i++ {
		v, hasValue := dumper.Next()
		if !hasValue {
			return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i)
		}
		ret[i] = v.(parquet.ByteArray).String()
	}
	return ret, nil
}

// GetBinaryVectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
	if r.colType != schemapb.DataType_BinaryVector {
		return nil, -1, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, -1, err
	}

	// TypeLength is the per-row size in bytes; the binary vector dimension is dim * 8 bits.
	dim := r.reader.RowGroup(0).Column(0).Descriptor().TypeLength()
	ret := make([]byte, int64(dim)*r.numRows)
	for i := 0; i < int(r.numRows); i++ {
		v, ok := dumper.Next()
		if !ok {
			return nil, -1, fmt.Errorf("unmatched row number: row %v, dim %v", r.numRows, dim)
		}
		parquetArray := v.(parquet.FixedLenByteArray)
		copy(ret[i*dim:(i+1)*dim], parquetArray)
	}
	return ret, dim * 8, nil
}
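
// Example (sketch): a hypothetical way to pull out row i from the flat byte
// slice returned by GetBinaryVectorFromPayload (assuming `reader` wraps a
// binary vector column); dim is reported in bits, so each row occupies dim/8 bytes:
//
//	data, dim, err := reader.GetBinaryVectorFromPayload()
//	if err != nil {
//		// handle error
//	}
//	bytesPerRow := dim / 8
//	row0 := data[0*bytesPerRow : 1*bytesPerRow]
//	_ = row0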

// GetFloatVectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
	if r.colType != schemapb.DataType_FloatVector {
		return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
	}
	dumper, err := r.createDumper()
	if err != nil {
		return nil, -1, err
	}

	// TypeLength is the per-row size in bytes; each float32 element takes 4 bytes.
	dim := r.reader.RowGroup(0).Column(0).Descriptor().TypeLength() / 4
	ret := make([]float32, int64(dim)*r.numRows)
	for i := 0; i < int(r.numRows); i++ {
		v, ok := dumper.Next()
		if !ok {
			return nil, -1, fmt.Errorf("unmatched row number: row %v, dim %v", r.numRows, dim)
		}
		parquetArray := v.(parquet.FixedLenByteArray)
		copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), parquetArray)
	}
	return ret, dim, nil
}
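
// Example (sketch): hypothetical indexing into the flat float32 slice returned
// by GetFloatVectorFromPayload (assuming `reader` wraps a float vector column);
// here dim is the number of float32 elements per row, so row i spans
// vecs[i*dim : (i+1)*dim]:
//
//	vecs, dim, err := reader.GetFloatVectorFromPayload()
//	if err != nil {
//		// handle error
//	}
//	i := 0
//	rowI := vecs[i*dim : (i+1)*dim]
//	_ = rowI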

// GetPayloadLengthFromReader returns the number of rows in the payload.
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
	return int(r.numRows), nil
}

// Close closes the payload reader
func (r *PayloadReader) Close() {
	r.reader.Close()
}

// Dumper iterates over the values of a single parquet column chunk, reading them in batches.
type Dumper struct {
	reader         file.ColumnChunkReader
	batchSize      int64
	valueOffset    int
	valuesBuffered int

	levelOffset    int64
	levelsBuffered int64
	defLevels      []int16
	repLevels      []int16

	valueBuffer interface{}
}

func (r *PayloadReader) createDumper() (*Dumper, error) {
	var valueBuffer interface{}
	switch r.reader.RowGroup(0).Column(0).(type) {
	case *file.BooleanColumnChunkReader:
		if r.colType != schemapb.DataType_Bool {
			return nil, errors.New("incorrect data type")
		}
		valueBuffer = make([]bool, r.numRows)
	case *file.Int32ColumnChunkReader:
		if r.colType != schemapb.DataType_Int32 && r.colType != schemapb.DataType_Int16 && r.colType != schemapb.DataType_Int8 {
			return nil, fmt.Errorf("incorrect data type, expect int32/int16/int8 but find %v", r.colType.String())
		}
		valueBuffer = make([]int32, r.numRows)
	case *file.Int64ColumnChunkReader:
		if r.colType != schemapb.DataType_Int64 {
			return nil, fmt.Errorf("incorrect data type, expect int64 but find %v", r.colType.String())
		}
		valueBuffer = make([]int64, r.numRows)
	case *file.Float32ColumnChunkReader:
		if r.colType != schemapb.DataType_Float {
			return nil, fmt.Errorf("incorrect data type, expect float32 but find %v", r.colType.String())
		}
		valueBuffer = make([]float32, r.numRows)
	case *file.Float64ColumnChunkReader:
		if r.colType != schemapb.DataType_Double {
			return nil, fmt.Errorf("incorrect data type, expect float64 but find %v", r.colType.String())
		}
		valueBuffer = make([]float64, r.numRows)
	case *file.ByteArrayColumnChunkReader:
		if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar {
			return nil, fmt.Errorf("incorrect data type, expect string/varchar but find %v", r.colType.String())
		}
		valueBuffer = make([]parquet.ByteArray, r.numRows)
	case *file.FixedLenByteArrayColumnChunkReader:
		if r.colType != schemapb.DataType_FloatVector && r.colType != schemapb.DataType_BinaryVector {
			return nil, fmt.Errorf("incorrect data type, expect floatvector/binaryvector but find %v", r.colType.String())
		}
		valueBuffer = make([]parquet.FixedLenByteArray, r.numRows)
	}

	return &Dumper{
		reader:      r.reader.RowGroup(0).Column(0),
		batchSize:   r.numRows,
		defLevels:   make([]int16, r.numRows),
		repLevels:   make([]int16, r.numRows),
		valueBuffer: valueBuffer,
	}, nil
}
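
// Sketch of the read loop that the getters above follow, assuming a dumper
// created by createDumper for the matching column type. Next refills its buffer
// via readNextBatch when the buffered values are exhausted and returns one
// value per call:
//
//	dumper, err := r.createDumper()
//	if err != nil {
//		// handle error
//	}
//	for i := int64(0); i < r.numRows; i++ {
//		v, ok := dumper.Next()
//		if !ok {
//			// fewer values than numRows were available
//			break
//		}
//		_ = v
//	}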

func (dump *Dumper) readNextBatch() {
	switch reader := dump.reader.(type) {
	case *file.BooleanColumnChunkReader:
		values := dump.valueBuffer.([]bool)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.Int32ColumnChunkReader:
		values := dump.valueBuffer.([]int32)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.Int64ColumnChunkReader:
		values := dump.valueBuffer.([]int64)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.Float32ColumnChunkReader:
		values := dump.valueBuffer.([]float32)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.Float64ColumnChunkReader:
		values := dump.valueBuffer.([]float64)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.Int96ColumnChunkReader:
		values := dump.valueBuffer.([]parquet.Int96)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.ByteArrayColumnChunkReader:
		values := dump.valueBuffer.([]parquet.ByteArray)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	case *file.FixedLenByteArrayColumnChunkReader:
		values := dump.valueBuffer.([]parquet.FixedLenByteArray)
		dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels)
	}

	dump.valueOffset = 0
	dump.levelOffset = 0
}

func (dump *Dumper) hasNext() bool {
	return dump.levelOffset < dump.levelsBuffered || dump.reader.HasNext()
}

func (dump *Dumper) Next() (interface{}, bool) {
	if dump.levelOffset == dump.levelsBuffered {
		if !dump.hasNext() {
			return nil, false
		}
		dump.readNextBatch()
		if dump.levelsBuffered == 0 {
			return nil, false
		}
	}

	defLevel := dump.defLevels[int(dump.levelOffset)]
	// repLevel := dump.repLevels[int(dump.levelOffset)]
	dump.levelOffset++

	if defLevel < dump.reader.Descriptor().MaxDefinitionLevel() {
		return nil, true
	}

	vb := reflect.ValueOf(dump.valueBuffer)
	v := vb.Index(dump.valueOffset).Interface()
	dump.valueOffset++

	return v, true
}