diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go
index 111fca1d2a..bfcbe4a4c5 100644
--- a/internal/storage/data_codec.go
+++ b/internal/storage/data_codec.go
@@ -1005,49 +1005,45 @@ func (deleteCodec *DeleteCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 
 	var pid, sid UniqueID
 	result := &DeleteData{}
-	for _, blob := range blobs {
+
+	deserializeBlob := func(blob *Blob) error {
 		binlogReader, err := NewBinlogReader(blob.Value)
 		if err != nil {
-			return InvalidUniqueID, InvalidUniqueID, nil, err
+			return err
 		}
+		defer binlogReader.Close()
 
 		pid, sid = binlogReader.PartitionID, binlogReader.SegmentID
 		eventReader, err := binlogReader.NextEventReader()
 		if err != nil {
-			binlogReader.Close()
-			return InvalidUniqueID, InvalidUniqueID, nil, err
+			return err
 		}
+		defer eventReader.Close()
 
-		dataset, err := eventReader.GetByteArrayDataSet()
+		rr, err := eventReader.GetArrowRecordReader()
 		if err != nil {
-			eventReader.Close()
-			binlogReader.Close()
-			return InvalidUniqueID, InvalidUniqueID, nil, err
+			return err
 		}
+		defer rr.Release()
 
-		batchSize := int64(1024)
-		for dataset.HasNext() {
-			stringArray, err := dataset.NextBatch(batchSize)
-			if err != nil {
-				return InvalidUniqueID, InvalidUniqueID, nil, err
-			}
-			for i := 0; i < len(stringArray); i++ {
+		for rr.Next() {
+			rec := rr.Record()
+			defer rec.Release()
+			column := rec.Column(0)
+			for i := 0; i < column.Len(); i++ {
 				deleteLog := &DeleteLog{}
-				if err = json.Unmarshal(stringArray[i], deleteLog); err != nil {
+				strVal := column.ValueStr(i)
+				if err = json.Unmarshal([]byte(strVal), deleteLog); err != nil {
 					// compatible with versions that only support int64 type primary keys
 					// compatible with fmt.Sprintf("%d,%d", pk, ts)
 					// compatible error info (unmarshal err invalid character ',' after top-level value)
-					splits := strings.Split(stringArray[i].String(), ",")
+					splits := strings.Split(strVal, ",")
 					if len(splits) != 2 {
-						eventReader.Close()
-						binlogReader.Close()
-						return InvalidUniqueID, InvalidUniqueID, nil, fmt.Errorf("the format of delta log is incorrect, %v can not be split", stringArray[i])
+						return fmt.Errorf("the format of delta log is incorrect, %v can not be split", strVal)
 					}
 					pk, err := strconv.ParseInt(splits[0], 10, 64)
 					if err != nil {
-						eventReader.Close()
-						binlogReader.Close()
-						return InvalidUniqueID, InvalidUniqueID, nil, err
+						return err
 					}
 					deleteLog.Pk = &Int64PrimaryKey{
 						Value: pk,
@@ -1055,17 +1051,20 @@ func (deleteCodec *DeleteCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 					deleteLog.PkType = int64(schemapb.DataType_Int64)
 					deleteLog.Ts, err = strconv.ParseUint(splits[1], 10, 64)
 					if err != nil {
-						eventReader.Close()
-						binlogReader.Close()
-						return InvalidUniqueID, InvalidUniqueID, nil, err
+						return err
 					}
 				}
 				result.Append(deleteLog.Pk, deleteLog.Ts)
 			}
 		}
-		eventReader.Close()
-		binlogReader.Close()
+		return nil
+	}
+
+	for _, blob := range blobs {
+		if err := deserializeBlob(blob); err != nil {
+			return InvalidUniqueID, InvalidUniqueID, nil, err
+		}
 	}
 
 	return pid, sid, result, nil
 }
diff --git a/internal/storage/payload.go b/internal/storage/payload.go
index d4fc137cbd..c810c67659 100644
--- a/internal/storage/payload.go
+++ b/internal/storage/payload.go
@@ -19,6 +19,7 @@ package storage
 import (
 	"github.com/apache/arrow/go/v12/parquet"
 	"github.com/apache/arrow/go/v12/parquet/file"
+	"github.com/apache/arrow/go/v12/parquet/pqarrow"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
@@ -69,6 +70,7 @@ type PayloadReaderInterface interface {
 	GetPayloadLengthFromReader() (int, error)
 
 	GetByteArrayDataSet() (*DataSet[parquet.ByteArray, *file.ByteArrayColumnChunkReader], error)
+	GetArrowRecordReader() (pqarrow.RecordReader, error)
 
 	ReleasePayloadReader() error
 	Close() error
diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go
index 6f6f185d40..1a5b462209 100644
--- a/internal/storage/payload_reader.go
+++ b/internal/storage/payload_reader.go
@@ -2,11 +2,14 @@ package storage
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 
 	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/memory"
 	"github.com/apache/arrow/go/v12/parquet"
 	"github.com/apache/arrow/go/v12/parquet/file"
+	"github.com/apache/arrow/go/v12/parquet/pqarrow"
 	"github.com/cockroachdb/errors"
 	"github.com/golang/protobuf/proto"
 
@@ -263,6 +266,19 @@ func (r *PayloadReader) GetByteArrayDataSet() (*DataSet[parquet.ByteArray, *file
 	return NewDataSet[parquet.ByteArray, *file.ByteArrayColumnChunkReader](r.reader, 0, r.numRows), nil
 }
 
+func (r *PayloadReader) GetArrowRecordReader() (pqarrow.RecordReader, error) {
+	arrowReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{BatchSize: 1024}, memory.DefaultAllocator)
+	if err != nil {
+		return nil, err
+	}
+
+	rr, err := arrowReader.GetRecordReader(context.Background(), nil, nil)
+	if err != nil {
+		return nil, err
+	}
+	return rr, nil
+}
+
 func (r *PayloadReader) GetArrayFromPayload() ([]*schemapb.ScalarField, error) {
 	if r.colType != schemapb.DataType_Array {
 		return nil, fmt.Errorf("failed to get string from datatype %v", r.colType.String())
diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go
index b49058545f..fe0db83732 100644
--- a/internal/storage/payload_test.go
+++ b/internal/storage/payload_test.go
@@ -18,8 +18,10 @@ package storage
 
 import (
 	"math"
+	"math/rand"
 	"testing"
 
+	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
@@ -1238,6 +1240,30 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
 		_, err = r.GetStringFromPayload()
 		assert.Error(t, err)
 	})
+	t.Run("TestGetArrayError", func(t *testing.T) {
+		w, err := NewPayloadWriter(schemapb.DataType_Bool)
+		require.Nil(t, err)
+		require.NotNil(t, w)
+
+		err = w.AddBoolToPayload([]bool{false, true, true})
+		assert.NoError(t, err)
+
+		err = w.FinishPayloadWriter()
+		assert.NoError(t, err)
+
+		buffer, err := w.GetPayloadBufferFromWriter()
+		assert.NoError(t, err)
+
+		r, err := NewPayloadReader(schemapb.DataType_Array, buffer)
+		assert.NoError(t, err)
+
+		_, err = r.GetArrayFromPayload()
+		assert.Error(t, err)
+
+		r.colType = 999
+		_, err = r.GetArrayFromPayload()
+		assert.Error(t, err)
+	})
 	t.Run("TestGetBinaryVectorError", func(t *testing.T) {
 		w, err := NewPayloadWriter(schemapb.DataType_Bool)
 		require.Nil(t, err)
@@ -1385,3 +1411,152 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
 		w.ReleasePayloadWriter()
 	})
 }
+
+func TestArrowRecordReader(t *testing.T) {
+	t.Run("TestArrowRecordReader", func(t *testing.T) {
+		w, err := NewPayloadWriter(schemapb.DataType_String)
+		assert.NoError(t, err)
+		defer w.Close()
+
+		err = w.AddOneStringToPayload("hello0")
+		assert.NoError(t, err)
+		err = w.AddOneStringToPayload("hello1")
+		assert.NoError(t, err)
+		err = w.AddOneStringToPayload("hello2")
+		assert.NoError(t, err)
+		err = w.FinishPayloadWriter()
+		assert.NoError(t, err)
+		length, err := w.GetPayloadLengthFromWriter()
+		assert.NoError(t, err)
+		assert.Equal(t, 3, length)
+		buffer, err := w.GetPayloadBufferFromWriter()
+		assert.NoError(t, err)
+
+		r, err := NewPayloadReader(schemapb.DataType_String, buffer)
+		assert.NoError(t, err)
+		length, err = r.GetPayloadLengthFromReader()
+		assert.NoError(t, err)
+		assert.Equal(t, 3, length)
+
+		rr, err := r.GetArrowRecordReader()
+		assert.NoError(t, err)
+
+		for rr.Next() {
+			rec := rr.Record()
+			arr := rec.Column(0).(*array.String)
+			defer rec.Release()
+
+			assert.Equal(t, "hello0", arr.Value(0))
+			assert.Equal(t, "hello1", arr.Value(1))
+			assert.Equal(t, "hello2", arr.Value(2))
+		}
+	})
+}
+
+func dataGen(size int) ([]byte, error) {
+	w, err := NewPayloadWriter(schemapb.DataType_String)
+	if err != nil {
+		return nil, err
+	}
+	defer w.Close()
+
+	letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+
+	for i := 0; i < size; i++ {
+		b := make([]rune, 20)
+		for i := range b {
+			b[i] = letterRunes[rand.Intn(len(letterRunes))]
+		}
+		w.AddOneStringToPayload(string(b))
+	}
+	err = w.FinishPayloadWriter()
+	if err != nil {
+		return nil, err
+	}
+	buffer, err := w.GetPayloadBufferFromWriter()
+	if err != nil {
+		return nil, err
+	}
+	return buffer, err
+}
+
+func BenchmarkDefaultReader(b *testing.B) {
+	size := 1000000
+	buffer, err := dataGen(size)
+	assert.NoError(b, err)
+
+	b.ResetTimer()
+	r, err := NewPayloadReader(schemapb.DataType_String, buffer)
+	require.Nil(b, err)
+	defer r.ReleasePayloadReader()
+
+	length, err := r.GetPayloadLengthFromReader()
+	assert.NoError(b, err)
+	assert.Equal(b, length, size)
+
+	d, err := r.GetStringFromPayload()
+	assert.NoError(b, err)
+	for i := 0; i < 100; i++ {
+		for _, de := range d {
+			assert.Equal(b, 20, len(de))
+		}
+	}
+}
+
+func BenchmarkDataSetReader(b *testing.B) {
+	size := 1000000
+	buffer, err := dataGen(size)
+	assert.NoError(b, err)
+
+	b.ResetTimer()
+	r, err := NewPayloadReader(schemapb.DataType_String, buffer)
+	require.Nil(b, err)
+	defer r.ReleasePayloadReader()
+
+	length, err := r.GetPayloadLengthFromReader()
+	assert.NoError(b, err)
+	assert.Equal(b, length, size)
+
+	ds, err := r.GetByteArrayDataSet()
+	assert.NoError(b, err)
+
+	for i := 0; i < 100; i++ {
+		for ds.HasNext() {
+			stringArray, err := ds.NextBatch(1024)
+			assert.NoError(b, err)
+			for _, de := range stringArray {
+				assert.Equal(b, 20, len(string(de)))
+			}
+		}
+	}
+}
+
+func BenchmarkArrowRecordReader(b *testing.B) {
+	size := 1000000
+	buffer, err := dataGen(size)
+	assert.NoError(b, err)
+
+	b.ResetTimer()
+	r, err := NewPayloadReader(schemapb.DataType_String, buffer)
+	require.Nil(b, err)
+	defer r.ReleasePayloadReader()
+
+	length, err := r.GetPayloadLengthFromReader()
+	assert.NoError(b, err)
+	assert.Equal(b, length, size)
+
+	rr, err := r.GetArrowRecordReader()
+	assert.NoError(b, err)
+	defer rr.Release()
+
+	for i := 0; i < 100; i++ {
+		for rr.Next() {
+			rec := rr.Record()
+			arr := rec.Column(0).(*array.String)
+			defer rec.Release()
+			for i := 0; i < arr.Len(); i++ {
+				assert.Equal(b, 20, len(arr.Value(i)))
+			}
+		}
+	}
+}
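
A note on the Deserialize refactor above: hoisting the loop body into the deserializeBlob closure is what makes the defer binlogReader.Close() / defer eventReader.Close() / defer rr.Release() lines safe, because the deferred calls run when the closure returns for each blob instead of piling up until Deserialize itself returns. Below is a minimal, self-contained sketch of that pattern; the resource and processOne names are hypothetical, not part of the diff:

package main

import "fmt"

// resource stands in for BinlogReader/EventReader: something that must
// be closed on every exit path.
type resource struct{ name string }

func (r *resource) Close() { fmt.Println("closed", r.name) }

// processOne mirrors deserializeBlob: a single defer covers every
// return, so error branches no longer need hand-written Close calls.
func processOne(name string) error {
	r := &resource{name: name}
	defer r.Close() // runs when this function returns, success or error

	// ... per-item work that may fail and simply `return err` ...
	return nil
}

func main() {
	// The caller keeps the original error contract, as the new
	// `for _, blob := range blobs` loop does in the diff.
	for _, n := range []string{"blob0", "blob1"} {
		if err := processOne(n); err != nil {
			fmt.Println("abort:", err)
			return
		}
	}
}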
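
And a sketch of how the new GetArrowRecordReader read path is consumed, written as if it sat in package storage beside the tests above. dataGen, NewPayloadReader, and the arrow/array import all come from the diff; consumeStringPayload is a hypothetical helper name added only for illustration. Unlike the test code, it releases each record as soon as it is processed, so at most one batch (up to the 1024-row BatchSize fixed in GetArrowRecordReader) stays alive at a time:

func consumeStringPayload() error {
	// Build a small string payload with the test helper from the diff.
	buffer, err := dataGen(10)
	if err != nil {
		return err
	}

	r, err := NewPayloadReader(schemapb.DataType_String, buffer)
	if err != nil {
		return err
	}
	defer r.ReleasePayloadReader()

	rr, err := r.GetArrowRecordReader()
	if err != nil {
		return err
	}
	defer rr.Release()

	for rr.Next() {
		rec := rr.Record()
		col := rec.Column(0).(*array.String)
		for i := 0; i < col.Len(); i++ {
			_ = col.Value(i) // process row i
		}
		rec.Release() // eager release: drop this batch before reading the next
	}
	return nil
}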