enhance: avoid memory copy and serde in mix compaction (#37479)

See: #37234

---------

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
Ted Xu 2024-11-08 08:30:57 +08:00 committed by GitHub
parent 1f7ce9e9c1
commit bc9562feb1
5 changed files with 361 additions and 100 deletions
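
In short, the change drops the per-row deserialize-to-Value and re-serialize round trip in mix compaction and passes Arrow-backed records from reader to writer directly, writing surviving rows as zero-copy slices. A minimal sketch of the pattern, using illustrative stand-in interfaces rather than the real storage types:

```go
package sketch

import "io"

// Illustrative stand-ins for storage.RecordReader, storage.RecordWriter and storage.Record.
type record interface{ Len() int }

type recordReader interface {
	Next() error
	Record() record
}

type recordWriter interface {
	Write(r record) error
}

// pipeRecords forwards whole batches from reader to writer: no per-row
// storage.Value is materialized and nothing is re-serialized in between,
// which is the shape of the new compaction write path.
func pipeRecords(r recordReader, w recordWriter) error {
	for {
		if err := r.Next(); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		if err := w.Write(r.Record()); err != nil {
			return err
		}
	}
}
```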

View File

@@ -23,6 +23,7 @@ import (
"math"
"time"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
@@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@@ -199,26 +201,43 @@ func (t *mixCompactionTask) writeSegment(ctx context.Context,
log.Warn("compact wrong, fail to merge deltalogs", zap.Error(err))
return
}
isValueDeleted := func(v *storage.Value) bool {
ts, ok := delta[v.PK.GetValue()]
isValueDeleted := func(pk any, ts typeutil.Timestamp) bool {
oldts, ok := delta[pk]
// insert and delete tasks carry the same ts on upsert,
// so the comparison must be < rather than <=
// to keep upserted data from being dropped by compaction
if ok && uint64(v.Timestamp) < ts {
if ok && ts < oldts {
deletedRowCount++
return true
}
// Filtering expired entity
if isExpiredEntity(t.plan.GetCollectionTtl(), t.currentTs, typeutil.Timestamp(ts)) {
expiredRowCount++
return true
}
return false
}
iter, err := storage.NewBinlogDeserializeReader(blobs, pkField.GetFieldID())
reader, err := storage.NewCompositeBinlogRecordReader(blobs)
if err != nil {
log.Warn("compact wrong, failed to new insert binlogs reader", zap.Error(err))
return
}
defer iter.Close()
defer reader.Close()
writeSlice := func(r storage.Record, start, end int) error {
sliced := r.Slice(start, end)
defer sliced.Release()
err = mWriter.WriteRecord(sliced)
if err != nil {
log.Warn("compact wrong, failed to writer row", zap.Error(err))
return err
}
return nil
}
for {
err = iter.Next()
err = reader.Next()
if err != nil {
if err == sio.EOF {
err = nil
@@ -228,23 +247,45 @@ func (t *mixCompactionTask) writeSegment(ctx context.Context,
return
}
}
v := iter.Value()
r := reader.Record()
pkArray := r.Column(pkField.FieldID)
tsArray := r.Column(common.TimeStampField).(*array.Int64)
if isValueDeleted(v) {
deletedRowCount++
continue
sliceStart := -1
rows := r.Len()
for i := 0; i < rows; i++ {
// Filtering deleted entities
var pk any
switch pkField.DataType {
case schemapb.DataType_Int64:
pk = pkArray.(*array.Int64).Value(i)
case schemapb.DataType_VarChar:
pk = pkArray.(*array.String).Value(i)
default:
panic("invalid data type")
}
ts := typeutil.Timestamp(tsArray.Value(i))
if isValueDeleted(pk, ts) {
if sliceStart != -1 {
err = writeSlice(r, sliceStart, i)
if err != nil {
return
}
sliceStart = -1
}
continue
}
if sliceStart == -1 {
sliceStart = i
}
}
// Filtering expired entity
if isExpiredEntity(t.plan.GetCollectionTtl(), t.currentTs, typeutil.Timestamp(v.Timestamp)) {
expiredRowCount++
continue
}
err = mWriter.Write(v)
if err != nil {
log.Warn("compact wrong, failed to writer row", zap.Error(err))
return
if sliceStart != -1 {
err = writeSlice(r, sliceStart, r.Len())
if err != nil {
return
}
}
}
return
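
The loop above filters rows column-wise: deleted and expired rows break the current run, and every maximal run of surviving rows is flushed as a single record slice. The same bookkeeping in isolation, as a minimal sketch with hypothetical keep and flush callbacks:

```go
// writeRuns flushes maximal contiguous runs of rows for which keep(i) is true,
// mirroring the sliceStart bookkeeping in writeSegment above.
func writeRuns(rows int, keep func(int) bool, flush func(start, end int) error) error {
	sliceStart := -1
	for i := 0; i < rows; i++ {
		if !keep(i) {
			if sliceStart != -1 {
				if err := flush(sliceStart, i); err != nil {
					return err
				}
				sliceStart = -1
			}
			continue
		}
		if sliceStart == -1 {
			sliceStart = i
		}
	}
	if sliceStart != -1 {
		return flush(sliceStart, rows)
	}
	return nil
}
```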

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"math"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
@@ -20,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -183,12 +185,7 @@ func (w *MultiSegmentWriter) getWriter() (*SegmentWriter, error) {
return w.writers[w.current], nil
}
func (w *MultiSegmentWriter) Write(v *storage.Value) error {
writer, err := w.getWriter()
if err != nil {
return err
}
func (w *MultiSegmentWriter) writeInternal(writer *SegmentWriter) error {
if writer.IsFull() {
// init segment fieldBinlogs if it is not exist
if _, ok := w.cachedMeta[writer.segmentID]; !ok {
@@ -206,6 +203,29 @@ func (w *MultiSegmentWriter) Write(v *storage.Value) error {
mergeFieldBinlogs(w.cachedMeta[writer.segmentID], partialBinlogs)
}
return nil
}
func (w *MultiSegmentWriter) WriteRecord(r storage.Record) error {
writer, err := w.getWriter()
if err != nil {
return err
}
if err := w.writeInternal(writer); err != nil {
return err
}
return writer.WriteRecord(r)
}
func (w *MultiSegmentWriter) Write(v *storage.Value) error {
writer, err := w.getWriter()
if err != nil {
return err
}
if err := w.writeInternal(writer); err != nil {
return err
}
return writer.Write(v)
}
@@ -358,6 +378,48 @@ func (w *SegmentWriter) WrittenMemorySize() uint64 {
return w.writer.WrittenMemorySize()
}
func (w *SegmentWriter) WriteRecord(r storage.Record) error {
tsArray := r.Column(common.TimeStampField).(*array.Int64)
rows := r.Len()
for i := 0; i < rows; i++ {
ts := typeutil.Timestamp(tsArray.Value(i))
if ts < w.tsFrom {
w.tsFrom = ts
}
if ts > w.tsTo {
w.tsTo = ts
}
switch schemapb.DataType(w.pkstats.PkType) {
case schemapb.DataType_Int64:
pkArray := r.Column(w.GetPkID()).(*array.Int64)
pk := &storage.Int64PrimaryKey{
Value: pkArray.Value(i),
}
w.pkstats.Update(pk)
case schemapb.DataType_VarChar:
pkArray := r.Column(w.GetPkID()).(*array.String)
pk := &storage.VarCharPrimaryKey{
Value: pkArray.Value(i),
}
w.pkstats.Update(pk)
default:
panic("invalid data type")
}
for fieldID, stats := range w.bm25Stats {
field, ok := r.Column(fieldID).(*array.Binary)
if !ok {
return fmt.Errorf("bm25 field value not found")
}
stats.AppendBytes(field.Value(i))
}
w.rowCount.Inc()
}
return w.writer.WriteRecord(r)
}
func (w *SegmentWriter) Write(v *storage.Value) error {
ts := typeutil.Timestamp(v.Timestamp)
if ts < w.tsFrom {
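
SegmentWriter.WriteRecord above pulls the primary key straight out of the Arrow column instead of going through a decoded storage.Value. A minimal sketch of that typed extraction, assuming the PK column is Int64 or VarChar; the helper name pkAt is illustrative:

```go
package sketch

import (
	"fmt"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/array"
)

// pkAt returns the primary key of row i from column col of rec, mirroring
// the type switch in SegmentWriter.WriteRecord above.
func pkAt(rec arrow.Record, col, i int) (any, error) {
	switch arr := rec.Column(col).(type) {
	case *array.Int64:
		return arr.Value(i), nil
	case *array.String:
		return arr.Value(i), nil
	default:
		return nil, fmt.Errorf("unsupported pk column type %T", arr)
	}
}
```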

View File

@@ -24,9 +24,12 @@ import (
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/apache/arrow/go/v12/parquet"
"github.com/apache/arrow/go/v12/parquet/compress"
"github.com/apache/arrow/go/v12/parquet/pqarrow"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"
@@ -40,6 +43,7 @@ type Record interface {
Column(i FieldID) arrow.Array
Len() int
Release()
Slice(start, end int) Record
}
type RecordReader interface {
@@ -50,6 +54,7 @@ type RecordReader interface {
type RecordWriter interface {
Write(r Record) error
GetWrittenUncompressed() uint64
Close()
}
@@ -64,6 +69,8 @@ type compositeRecord struct {
schema map[FieldID]schemapb.DataType
}
var _ Record = (*compositeRecord)(nil)
func (r *compositeRecord) Column(i FieldID) arrow.Array {
return r.recs[i].Column(0)
}
@@ -93,6 +100,17 @@ func (r *compositeRecord) ArrowSchema() *arrow.Schema {
return arrow.NewSchema(fields, nil)
}
func (r *compositeRecord) Slice(start, end int) Record {
slices := make(map[FieldID]arrow.Record)
for i, rec := range r.recs {
slices[i] = rec.NewSlice(int64(start), int64(end))
}
return &compositeRecord{
recs: slices,
schema: r.schema,
}
}
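
compositeRecord.Slice delegates to Arrow's Record.NewSlice, which shares the underlying buffers instead of copying them; the caller releases the slice when done (as writeSlice does in the compaction task). A small sketch of the zero-copy semantics:

```go
package sketch

import "github.com/apache/arrow/go/v12/arrow"

// sliceRows returns rows [start, end) of rec as a zero-copy view.
// The caller must Release the returned record; the parent rec still
// owns the underlying Arrow buffers.
func sliceRows(rec arrow.Record, start, end int) arrow.Record {
	return rec.NewSlice(int64(start), int64(end))
}
```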
type serdeEntry struct {
// arrowType returns the arrow type for the given dimension
arrowType func(int) arrow.DataType
@@ -527,6 +545,17 @@ func (deser *DeserializeReader[T]) Next() error {
return nil
}
func (deser *DeserializeReader[T]) NextRecord() (Record, error) {
if len(deser.values) != 0 {
return nil, errors.New("deserialize result is not empty")
}
if err := deser.rr.Next(); err != nil {
return nil, err
}
return deser.rr.Record(), nil
}
func (deser *DeserializeReader[T]) Value() T {
return deser.values[deser.pos]
}
@@ -580,6 +609,16 @@ func (r *selectiveRecord) Release() {
// do nothing.
}
func (r *selectiveRecord) Slice(start, end int) Record {
panic("not implemented")
}
func calculateArraySize(a arrow.Array) int {
return lo.SumBy[*memory.Buffer, int](a.Data().Buffers(), func(b *memory.Buffer) int {
return b.Len()
})
}
func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord {
dt, ok := r.Schema()[selectedFieldId]
if !ok {
@@ -594,16 +633,29 @@ func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord {
}
}
var _ RecordWriter = (*compositeRecordWriter)(nil)
var _ RecordWriter = (*CompositeRecordWriter)(nil)
type compositeRecordWriter struct {
type CompositeRecordWriter struct {
writers map[FieldID]RecordWriter
writtenUncompressed uint64
}
func (crw *compositeRecordWriter) Write(r Record) error {
func (crw *CompositeRecordWriter) GetWrittenUncompressed() uint64 {
return crw.writtenUncompressed
}
func (crw *CompositeRecordWriter) Write(r Record) error {
if len(r.Schema()) != len(crw.writers) {
return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers))
}
var bytes uint64
for fid := range r.Schema() {
arr := r.Column(fid)
bytes += uint64(calculateArraySize(arr))
}
crw.writtenUncompressed += bytes
for fieldId, w := range crw.writers {
sr := newSelectiveRecord(r, fieldId)
if err := w.Write(sr); err != nil {
@@ -613,7 +665,7 @@ func (crw *compositeRecordWriter) Write(r Record) error {
return nil
}
func (crw *compositeRecordWriter) Close() {
func (crw *CompositeRecordWriter) Close() {
if crw != nil {
for _, w := range crw.writers {
if w != nil {
@@ -623,8 +675,8 @@ func (crw *compositeRecordWriter) Close() {
}
}
func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecordWriter {
return &compositeRecordWriter{
func NewCompositeRecordWriter(writers map[FieldID]RecordWriter) *CompositeRecordWriter {
return &CompositeRecordWriter{
writers: writers,
}
}
@@ -640,22 +692,29 @@ func WithRecordWriterProps(writerProps *parquet.WriterProperties) RecordWriterOp
}
type singleFieldRecordWriter struct {
fw *pqarrow.FileWriter
fieldId FieldID
schema *arrow.Schema
numRows int
fw *pqarrow.FileWriter
fieldId FieldID
schema *arrow.Schema
writerProps *parquet.WriterProperties
numRows int
writtenUncompressed uint64
}
func (sfw *singleFieldRecordWriter) Write(r Record) error {
sfw.numRows += r.Len()
a := r.Column(sfw.fieldId)
sfw.writtenUncompressed += uint64(a.Data().Buffers()[0].Len())
rec := array.NewRecord(sfw.schema, []arrow.Array{a}, int64(r.Len()))
defer rec.Release()
return sfw.fw.WriteBuffered(rec)
}
func (sfw *singleFieldRecordWriter) GetWrittenUncompressed() uint64 {
return sfw.writtenUncompressed
}
func (sfw *singleFieldRecordWriter) Close() {
sfw.fw.Close()
}
@@ -687,7 +746,8 @@ type multiFieldRecordWriter struct {
fieldIds []FieldID
schema *arrow.Schema
numRows int
numRows int
writtenUncompressed uint64
}
func (mfw *multiFieldRecordWriter) Write(r Record) error {
@@ -695,12 +755,17 @@ func (mfw *multiFieldRecordWriter) Write(r Record) error {
columns := make([]arrow.Array, len(mfw.fieldIds))
for i, fieldId := range mfw.fieldIds {
columns[i] = r.Column(fieldId)
mfw.writtenUncompressed += uint64(calculateArraySize(columns[i]))
}
rec := array.NewRecord(mfw.schema, columns, int64(r.Len()))
defer rec.Release()
return mfw.fw.WriteBuffered(rec)
}
func (mfw *multiFieldRecordWriter) GetWrittenUncompressed() uint64 {
return mfw.writtenUncompressed
}
func (mfw *multiFieldRecordWriter) Close() {
mfw.fw.Close()
}
@@ -765,6 +830,23 @@ func (sw *SerializeWriter[T]) Write(value T) error {
return nil
}
func (sw *SerializeWriter[T]) WriteRecord(r Record) error {
if len(sw.buffer) != 0 {
return errors.New("serialize buffer is not empty")
}
if err := sw.rw.Write(r); err != nil {
return err
}
size := 0
for fid := range r.Schema() {
size += calculateArraySize(r.Column(fid))
}
sw.writtenMemorySize.Add(uint64(size))
return nil
}
func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 {
return sw.writtenMemorySize.Load()
}
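
RecordWriter now also exposes GetWrittenUncompressed, and the composite and Parquet-backed writers account for it by summing the Arrow buffer sizes of every record they receive. A minimal sketch of a writer satisfying the extended contract, assuming it lives alongside Record in the storage package; countingWriter is illustrative, not part of the change:

```go
// countingWriter is a toy RecordWriter that only tracks the uncompressed
// Arrow buffer bytes of the records written to it.
type countingWriter struct {
	written uint64
}

func (w *countingWriter) Write(r Record) error {
	for fid := range r.Schema() {
		for _, buf := range r.Column(fid).Data().Buffers() {
			if buf != nil { // e.g. a missing validity bitmap shows up as nil
				w.written += uint64(buf.Len())
			}
		}
	}
	return nil
}

func (w *countingWriter) GetWrittenUncompressed() uint64 { return w.written }

func (w *countingWriter) Close() {}
```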

View File

@@ -28,7 +28,6 @@ import (
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
@@ -38,9 +37,9 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var _ RecordReader = (*compositeBinlogRecordReader)(nil)
var _ RecordReader = (*CompositeBinlogRecordReader)(nil)
type compositeBinlogRecordReader struct {
type CompositeBinlogRecordReader struct {
blobs [][]*Blob
blobPos int
@@ -51,7 +50,7 @@ type compositeBinlogRecordReader struct {
r compositeRecord
}
func (crr *compositeBinlogRecordReader) iterateNextBatch() error {
func (crr *CompositeBinlogRecordReader) iterateNextBatch() error {
if crr.closers != nil {
for _, close := range crr.closers {
if close != nil {
@@ -91,7 +90,7 @@ func (crr *compositeBinlogRecordReader) iterateNextBatch() error {
return nil
}
func (crr *compositeBinlogRecordReader) Next() error {
func (crr *CompositeBinlogRecordReader) Next() error {
if crr.rrs == nil {
if crr.blobs == nil || len(crr.blobs) == 0 {
return io.EOF
@@ -135,11 +134,11 @@ func (crr *compositeBinlogRecordReader) Next() error {
return nil
}
func (crr *compositeBinlogRecordReader) Record() Record {
func (crr *CompositeBinlogRecordReader) Record() Record {
return &crr.r
}
func (crr *compositeBinlogRecordReader) Close() {
func (crr *CompositeBinlogRecordReader) Close() {
for _, close := range crr.closers {
if close != nil {
close()
@@ -158,7 +157,7 @@ func parseBlobKey(blobKey string) (colId FieldID, logId UniqueID) {
return InvalidUniqueID, InvalidUniqueID
}
func newCompositeBinlogRecordReader(blobs []*Blob) (*compositeBinlogRecordReader, error) {
func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader, error) {
blobMap := make(map[FieldID][]*Blob)
for _, blob := range blobs {
colId, _ := parseBlobKey(blob.Key)
@@ -178,13 +177,13 @@ func newCompositeBinlogRecordReader(blobs []*Blob) (*compositeBinlogRecordReader
})
sortedBlobs = append(sortedBlobs, blobsForField)
}
return &compositeBinlogRecordReader{
return &CompositeBinlogRecordReader{
blobs: sortedBlobs,
}, nil
}
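
With the reader exported as CompositeBinlogRecordReader, callers outside the storage package (such as the compaction task above) can iterate insert binlogs batch by batch without the Value-based deserializer. A hedged usage sketch; readAllBatches is illustrative and assumes blobs holds the insert-log blobs of one segment:

```go
package sketch

import (
	"io"

	"github.com/milvus-io/milvus/internal/storage"
)

// readAllBatches walks every Arrow-backed batch of the given insert binlogs
// and hands it to visit, mirroring the loop in writeSegment above.
func readAllBatches(blobs []*storage.Blob, visit func(storage.Record) error) error {
	reader, err := storage.NewCompositeBinlogRecordReader(blobs)
	if err != nil {
		return err
	}
	defer reader.Close()
	for {
		if err := reader.Next(); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		if err := visit(reader.Record()); err != nil {
			return err
		}
	}
}
```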
func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*DeserializeReader[*Value], error) {
reader, err := newCompositeBinlogRecordReader(blobs)
reader, err := NewCompositeBinlogRecordReader(blobs)
if err != nil {
return nil, err
}
@@ -234,7 +233,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize
}
func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog], error) {
reader, err := newCompositeBinlogRecordReader(blobs)
reader, err := NewCompositeBinlogRecordReader(blobs)
if err != nil {
return nil, err
}
@@ -264,8 +263,6 @@ type BinlogStreamWriter struct {
segmentID UniqueID
fieldSchema *schemapb.FieldSchema
memorySize int // To be updated on the fly
buf bytes.Buffer
rw *singleFieldRecordWriter
}
@@ -306,7 +303,7 @@ func (bsw *BinlogStreamWriter) Finalize() (*Blob, error) {
Key: strconv.Itoa(int(bsw.fieldSchema.FieldID)),
Value: b.Bytes(),
RowNum: int64(bsw.rw.numRows),
MemorySize: int64(bsw.memorySize),
MemorySize: int64(bsw.rw.writtenUncompressed),
}, nil
}
@@ -319,7 +316,7 @@ func (bsw *BinlogStreamWriter) writeBinlogHeaders(w io.Writer) error {
de := NewBaseDescriptorEvent(bsw.collectionID, bsw.partitionID, bsw.segmentID)
de.PayloadDataType = bsw.fieldSchema.DataType
de.FieldID = bsw.fieldSchema.FieldID
de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(bsw.memorySize))
de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(bsw.rw.writtenUncompressed)))
de.descriptorEventData.AddExtra(nullableKey, bsw.fieldSchema.Nullable)
if err := de.Write(w); err != nil {
return err
@@ -356,6 +353,50 @@ func NewBinlogStreamWriters(collectionID, partitionID, segmentID UniqueID,
return bws
}
func ValueSerializer(v []*Value, fieldSchema []*schemapb.FieldSchema) (Record, uint64, error) {
builders := make(map[FieldID]array.Builder, len(fieldSchema))
types := make(map[FieldID]schemapb.DataType, len(fieldSchema))
for _, f := range fieldSchema {
dim, _ := typeutil.GetDim(f)
builders[f.FieldID] = array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim)))
types[f.FieldID] = f.DataType
}
var memorySize uint64
for _, vv := range v {
m := vv.Value.(map[FieldID]any)
for fid, e := range m {
typeEntry, ok := serdeMap[types[fid]]
if !ok {
panic("unknown type")
}
ok = typeEntry.serialize(builders[fid], e)
if !ok {
return nil, 0, merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", types[fid]))
}
}
}
arrays := make([]arrow.Array, len(types))
fields := make([]arrow.Field, len(types))
field2Col := make(map[FieldID]int, len(types))
i := 0
for fid, builder := range builders {
arrays[i] = builder.NewArray()
memorySize += uint64(calculateArraySize(arrays[i]))
builder.Release()
fields[i] = arrow.Field{
Name: strconv.Itoa(int(fid)),
Type: arrays[i].DataType(),
Nullable: true, // No nullable check here.
}
field2Col[fid] = i
i++
}
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), memorySize, nil
}
func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, segmentID UniqueID,
eventWriters map[FieldID]*BinlogStreamWriter, batchSize int,
) (*SerializeWriter[*Value], error) {
@@ -368,53 +409,9 @@ func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, se
}
rws[fid] = rw
}
compositeRecordWriter := newCompositeRecordWriter(rws)
compositeRecordWriter := NewCompositeRecordWriter(rws)
return NewSerializeRecordWriter[*Value](compositeRecordWriter, func(v []*Value) (Record, uint64, error) {
builders := make(map[FieldID]array.Builder, len(schema.Fields))
types := make(map[FieldID]schemapb.DataType, len(schema.Fields))
for _, f := range schema.Fields {
dim, _ := typeutil.GetDim(f)
builders[f.FieldID] = array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim)))
types[f.FieldID] = f.DataType
}
var memorySize uint64
for _, vv := range v {
m := vv.Value.(map[FieldID]any)
for fid, e := range m {
typeEntry, ok := serdeMap[types[fid]]
if !ok {
panic("unknown type")
}
ok = typeEntry.serialize(builders[fid], e)
if !ok {
return nil, 0, merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", types[fid]))
}
}
}
arrays := make([]arrow.Array, len(types))
fields := make([]arrow.Field, len(types))
field2Col := make(map[FieldID]int, len(types))
i := 0
for fid, builder := range builders {
arrays[i] = builder.NewArray()
size := lo.SumBy[*memory.Buffer, int](arrays[i].Data().Buffers(), func(b *memory.Buffer) int {
return b.Len()
})
eventWriters[fid].memorySize += size
memorySize += uint64(size)
builder.Release()
fields[i] = arrow.Field{
Name: strconv.Itoa(int(fid)),
Type: arrays[i].DataType(),
Nullable: true, // No nullable check here.
}
field2Col[fid] = i
i++
}
return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), memorySize, nil
return ValueSerializer(v, schema.Fields)
}, batchSize), nil
}
@@ -515,7 +512,7 @@ func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int
return nil, err
}
rws[0] = rw
compositeRecordWriter := newCompositeRecordWriter(rws)
compositeRecordWriter := NewCompositeRecordWriter(rws)
return NewSerializeRecordWriter[*DeleteLog](compositeRecordWriter, func(v []*DeleteLog) (Record, uint64, error) {
builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String)

View File

@@ -19,9 +19,13 @@ package storage
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"sort"
"strconv"
"testing"
"time"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
@@ -29,9 +33,12 @@ import (
"github.com/apache/arrow/go/v12/parquet/file"
"github.com/apache/arrow/go/v12/parquet/pqarrow"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
)
func TestBinlogDeserializeReader(t *testing.T) {
@@ -182,6 +189,78 @@ func TestBinlogSerializeWriter(t *testing.T) {
})
}
func BenchmarkSerializeWriter(b *testing.B) {
const (
dim = 128
numRows = 200000
)
var (
rId = &schemapb.FieldSchema{FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64}
ts = &schemapb.FieldSchema{FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64}
pk = &schemapb.FieldSchema{FieldID: 100, Name: "pk", IsPrimaryKey: true, DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: common.MaxLengthKey, Value: "100"}}}
f = &schemapb.FieldSchema{FieldID: 101, Name: "random", DataType: schemapb.DataType_Double}
// fVec = &schemapb.FieldSchema{FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: strconv.Itoa(dim)}}}
)
schema := &schemapb.CollectionSchema{Name: "test-aaa", Fields: []*schemapb.FieldSchema{rId, ts, pk, f}}
// prepare data values
start := time.Now()
vec := make([]float32, dim)
for j := 0; j < dim; j++ {
vec[j] = rand.Float32()
}
values := make([]*Value, numRows)
for i := 0; i < numRows; i++ {
value := &Value{}
value.Value = make(map[int64]interface{}, len(schema.GetFields()))
m := value.Value.(map[int64]interface{})
for _, field := range schema.GetFields() {
switch field.GetDataType() {
case schemapb.DataType_Int64:
m[field.GetFieldID()] = int64(i)
case schemapb.DataType_VarChar:
k := fmt.Sprintf("test_pk_%d", i)
m[field.GetFieldID()] = k
value.PK = &VarCharPrimaryKey{
Value: k,
}
case schemapb.DataType_Double:
m[field.GetFieldID()] = float64(i)
case schemapb.DataType_FloatVector:
m[field.GetFieldID()] = vec
}
}
value.ID = int64(i)
value.Timestamp = int64(0)
value.IsDeleted = false
value.Value = m
values[i] = value
}
sort.Slice(values, func(i, j int) bool {
return values[i].PK.LT(values[j].PK)
})
log.Info("prepare data done", zap.Int("len", len(values)), zap.Duration("dur", time.Since(start)))
b.ResetTimer()
sizes := []int{100, 1000, 10000, 100000}
for _, s := range sizes {
b.Run(fmt.Sprintf("batch size=%d", s), func(b *testing.B) {
for i := 0; i < b.N; i++ {
writers := NewBinlogStreamWriters(0, 0, 0, schema.Fields)
writer, err := NewBinlogSerializeWriter(schema, 0, 0, writers, s)
assert.NoError(b, err)
for _, v := range values {
_ = writer.Write(v)
assert.NoError(b, err)
}
writer.Close()
}
})
}
}
func TestNull(t *testing.T) {
t.Run("test null", func(t *testing.T) {
schema := generateTestSchema()