89397d1e66 #36252: remove the unneeded type check so that writing Parquet with a null-type writer succeeds.
Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importv2

import (
	"context"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"os"
	"testing"
	"time"

	"github.com/apache/arrow/go/v12/arrow/array"
	"github.com/apache/arrow/go/v12/parquet"
	"github.com/apache/arrow/go/v12/parquet/pqarrow"
	"github.com/samber/lo"
	"github.com/sbinet/npyio"
	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/storage"
	pq "github.com/milvus-io/milvus/internal/util/importutilv2/parquet"
	"github.com/milvus-io/milvus/internal/util/testutil"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/tests/integration"
)

// dim is the vector dimension used for all vector fields in the generated test data.
const dim = 128
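
// CheckLogID returns an error if any binlog in the given field binlogs
// still has a zero (i.e. unassigned) log ID.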
func CheckLogID(fieldBinlogs []*datapb.FieldBinlog) error {
	for _, fieldBinlog := range fieldBinlogs {
		for _, l := range fieldBinlog.GetBinlogs() {
			if l.GetLogID() == 0 {
				return fmt.Errorf("unexpected log id 0")
			}
		}
	}
	return nil
}
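
// GenerateParquetFile writes numRows rows of generated test data for the
// given schema to a Parquet file at filePath, discarding the in-memory data.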
func GenerateParquetFile(filePath string, schema *schemapb.CollectionSchema, numRows int) error {
	_, err := GenerateParquetFileAndReturnInsertData(filePath, schema, numRows)
	return err
}
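
// GenerateParquetFileAndReturnInsertData writes numRows rows of generated
// test data to a Parquet file at filePath and also returns the generated
// InsertData so callers can verify what was written.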
func GenerateParquetFileAndReturnInsertData(filePath string, schema *schemapb.CollectionSchema, numRows int) (*storage.InsertData, error) {
	w, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
	if err != nil {
		return nil, err
	}

	pqSchema, err := pq.ConvertToArrowSchema(schema, false)
	if err != nil {
		return nil, err
	}
	// Write all rows into a single row group.
	fw, err := pqarrow.NewFileWriter(pqSchema, w, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(numRows))), pqarrow.DefaultWriterProps())
	if err != nil {
		return nil, err
	}
	defer fw.Close()

	insertData, err := testutil.CreateInsertData(schema, numRows)
	if err != nil {
		return nil, err
	}

	columns, err := testutil.BuildArrayData(schema, insertData, false)
	if err != nil {
		return nil, err
	}

	recordBatch := array.NewRecord(pqSchema, columns, int64(numRows))
	return insertData, fw.Write(recordBatch)
}
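
// GenerateNumpyFiles writes one .npy file per schema field (skipping
// auto-generated primary keys) under the chunk manager's root path and
// returns an ImportFile listing the written paths.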
func GenerateNumpyFiles(cm storage.ChunkManager, schema *schemapb.CollectionSchema, rowCount int) (*internalpb.ImportFile, error) {
	writeFn := func(path string, data interface{}) error {
		f, err := os.Create(path)
		if err != nil {
			return err
		}
		defer f.Close()

		err = npyio.Write(f, data)
		if err != nil {
			return err
		}

		return nil
	}

	insertData, err := testutil.CreateInsertData(schema, rowCount)
	if err != nil {
		return nil, err
	}

	var data interface{}
	paths := make([]string, 0)
	for _, field := range schema.GetFields() {
		// Auto-generated primary keys are not provided by import files.
		if field.GetAutoID() && field.GetIsPrimaryKey() {
			continue
		}
		path := fmt.Sprintf("%s/%s.npy", cm.RootPath(), field.GetName())

		fieldID := field.GetFieldID()
		fieldData := insertData.Data[fieldID]
		dType := field.GetDataType()
		switch dType {
		case schemapb.DataType_BinaryVector:
			// npyio needs fixed-size rows, so reshape the flat byte slice
			// into rows of dim/8 bytes (one bit per dimension).
			rows := fieldData.GetDataRows().([]byte)
			if dim != fieldData.(*storage.BinaryVectorFieldData).Dim {
				panic(fmt.Sprintf("dim mismatch: %d, %d", dim, fieldData.(*storage.BinaryVectorFieldData).Dim))
			}
			const rowBytes = dim / 8
			chunked := lo.Chunk(rows, rowBytes)
			chunkedRows := make([][rowBytes]byte, len(chunked))
			for i, innerSlice := range chunked {
				copy(chunkedRows[i][:], innerSlice)
			}
			data = chunkedRows
		case schemapb.DataType_FloatVector:
			rows := fieldData.GetDataRows().([]float32)
			if dim != fieldData.(*storage.FloatVectorFieldData).Dim {
				panic(fmt.Sprintf("dim mismatch: %d, %d", dim, fieldData.(*storage.FloatVectorFieldData).Dim))
			}
			chunked := lo.Chunk(rows, dim)
			chunkedRows := make([][dim]float32, len(chunked))
			for i, innerSlice := range chunked {
				copy(chunkedRows[i][:], innerSlice)
			}
			data = chunkedRows
		case schemapb.DataType_Float16Vector:
			// Two bytes per float16 element.
			rows := fieldData.GetDataRows().([]byte)
			if dim != fieldData.(*storage.Float16VectorFieldData).Dim {
				panic(fmt.Sprintf("dim mismatch: %d, %d", dim, fieldData.(*storage.Float16VectorFieldData).Dim))
			}
			const rowBytes = dim * 2
			chunked := lo.Chunk(rows, rowBytes)
			chunkedRows := make([][rowBytes]byte, len(chunked))
			for i, innerSlice := range chunked {
				copy(chunkedRows[i][:], innerSlice)
			}
			data = chunkedRows
		case schemapb.DataType_BFloat16Vector:
			// Two bytes per bfloat16 element.
			rows := fieldData.GetDataRows().([]byte)
			if dim != fieldData.(*storage.BFloat16VectorFieldData).Dim {
				panic(fmt.Sprintf("dim mismatch: %d, %d", dim, fieldData.(*storage.BFloat16VectorFieldData).Dim))
			}
			const rowBytes = dim * 2
			chunked := lo.Chunk(rows, rowBytes)
			chunkedRows := make([][rowBytes]byte, len(chunked))
			for i, innerSlice := range chunked {
				copy(chunkedRows[i][:], innerSlice)
			}
			data = chunkedRows
		case schemapb.DataType_SparseFloatVector:
			data = insertData.Data[fieldID].(*storage.SparseFloatVectorFieldData).GetContents()
		default:
			data = insertData.Data[fieldID].GetDataRows()
		}

		err := writeFn(path, data)
		if err != nil {
			return nil, err
		}
		paths = append(paths, path)
	}
	return &internalpb.ImportFile{
		Paths: paths,
	}, nil
}
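
// GenerateJSONFile writes count rows of generated test data for the given
// schema to filePath as a JSON array of row objects.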
func GenerateJSONFile(t *testing.T, filePath string, schema *schemapb.CollectionSchema, count int) {
	insertData, err := testutil.CreateInsertData(schema, count)
	assert.NoError(t, err)

	rows, err := testutil.CreateInsertDataRowsForJSON(schema, insertData)
	assert.NoError(t, err)

	jsonBytes, err := json.Marshal(rows)
	assert.NoError(t, err)

	err = os.WriteFile(filePath, jsonBytes, 0o644) // nolint
	assert.NoError(t, err)
}
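
// GenerateCSVFile writes count rows of generated test data for the given
// schema to filePath as CSV, and returns the separator used so callers can
// pass it along with the import request.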
func GenerateCSVFile(t *testing.T, filePath string, schema *schemapb.CollectionSchema, count int) rune {
	insertData, err := testutil.CreateInsertData(schema, count)
	assert.NoError(t, err)

	sep := ','
	nullkey := ""

	csvData, err := testutil.CreateInsertDataForCSV(schema, insertData, nullkey)
	assert.NoError(t, err)

	wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
	assert.NoError(t, err)
	defer wf.Close()

	writer := csv.NewWriter(wf)
	writer.Comma = sep
	// WriteAll flushes internally; check its error rather than the stale
	// error left over from os.OpenFile.
	err = writer.WriteAll(csvData)
	assert.NoError(t, err)

	return sep
}
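
// WaitForImportDone polls the import job's progress once per second and
// returns nil when the job completes, an import error if it fails, or the
// underlying error if the progress request itself fails.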
func WaitForImportDone(ctx context.Context, c *integration.MiniClusterV2, jobID string) error {
	for {
		resp, err := c.Proxy.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{
			JobID: jobID,
		})
		if err != nil {
			return err
		}
		if err = merr.Error(resp.GetStatus()); err != nil {
			return err
		}
		switch resp.GetState() {
		case internalpb.ImportJobState_Completed:
			return nil
		case internalpb.ImportJobState_Failed:
			return merr.WrapErrImportFailed(resp.GetReason())
		default:
			log.Info("import progress", zap.String("jobID", jobID),
				zap.Int64("progress", resp.GetProgress()),
				zap.String("state", resp.GetState().String()))
			time.Sleep(1 * time.Second)
		}
	}
}
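
// Usage sketch (illustrative only, not part of this file): an integration
// test typically combines these helpers roughly as follows, assuming a
// running MiniClusterV2 `c`, a context `ctx`, a collection schema `schema`,
// and a jobID obtained from submitting the import request:
//
//	filePath := fmt.Sprintf("%s/import_data.parquet", t.TempDir())
//	err := GenerateParquetFile(filePath, schema, 100)
//	assert.NoError(t, err)
//	// ... create the collection and submit an import request for filePath,
//	// taking jobID from the response ...
//	err = WaitForImportDone(ctx, c, jobID)
//	assert.NoError(t, err)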