Support import numpy file (#16348)
Signed-off-by: groot <yihua.mo@zilliz.com>
This commit is contained in:
parent 6682d1b635
commit be8d9a8b6b

go.mod | 5
@@ -32,6 +32,8 @@ require (
	github.com/pierrec/lz4 v2.5.2+incompatible // indirect
	github.com/pkg/errors v0.9.1
	github.com/prometheus/client_golang v1.11.0
	github.com/quasilyte/go-ruleguard v0.2.1 // indirect
	github.com/sbinet/npyio v0.6.0
	github.com/shirou/gopsutil v3.21.8+incompatible
	github.com/spaolacci/murmur3 v1.1.0
	github.com/spf13/cast v1.3.1

@@ -59,4 +61,5 @@ replace (
	github.com/dgrijalva/jwt-go => github.com/golang-jwt/jwt v3.2.2+incompatible // Fix security alert for jwt-go 3.2.0
	github.com/keybase/go-keychain => github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4
	google.golang.org/grpc => google.golang.org/grpc v1.38.0
)
go.sum | 6
@@ -110,6 +110,7 @@ github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJm
github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=

@@ -584,6 +585,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=

@@ -598,6 +601,8 @@ github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfF
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4=
github.com/sbinet/npyio v0.6.0 h1:IyqqQIzRjDym9xnIXsToCKei/qCzxDP+Y74KoMlMgXo=
github.com/sbinet/npyio v0.6.0/go.mod h1:/q3BNr6dJOy+t6h7RZchTJ0nwRJO52mivaem29WE1j8=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shirou/gopsutil v3.21.8+incompatible h1:sh0foI8tMRlCidUJR+KzqWYWxrkuuPIGiO6Vp+KXdCU=
github.com/shirou/gopsutil v3.21.8+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=

@@ -1038,6 +1043,7 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
internal/util/importutil/import_wrapper.go

@@ -7,6 +7,7 @@ import (
	"os"
	"path"
	"strconv"
+	"strings"

	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/common"

@@ -89,6 +90,13 @@ func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldD
	log.Debug(msg, stats...)
}

+func getFileNameAndExt(filePath string) (string, string) {
+	fileName := path.Base(filePath)
+	fileType := path.Ext(fileName)
+	fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
+	return fileNameWithoutExt, fileType
+}
+
// import process entry
// filePath and rowBased are from ImportTask
// if onlyValidate is true, this process only does validation; no data is generated and callFlushFunc will not be called

@@ -99,8 +107,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
	// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
	for i := 0; i < len(filePaths); i++ {
		filePath := filePaths[i]
-		fileName := path.Base(filePath)
-		fileType := path.Ext(fileName)
+		_, fileType := getFileNameAndExt(filePath)
		log.Debug("import wrapper: row-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

		if fileType == JSONFileExt {

@@ -183,8 +190,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
	// parse/validate/consume data
	for i := 0; i < len(filePaths); i++ {
		filePath := filePaths[i]
-		fileName := path.Base(filePath)
-		fileType := path.Ext(fileName)
+		fileName, fileType := getFileNameAndExt(filePath)
		log.Debug("import wrapper: column-based file", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

		if fileType == JSONFileExt {

@@ -218,7 +224,28 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
				return err
			}
+		} else if fileType == NumpyFileExt {
+			file, err := os.Open(filePath)
+			if err != nil {
+				log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
+				return err
+			}
+			defer file.Close()
+
+			// the numpy parser returns a storage.FieldData; construct a map[string]storage.FieldData to combine
+			flushFunc := func(field storage.FieldData) error {
+				fields := make(map[string]storage.FieldData)
+				fields[fileName] = field
+				combineFunc(fields)
+				return nil
+			}
+
+			// for a numpy file, the file name (without extension) is the field name
+			parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
+			err = parser.Parse(file, fileName, onlyValidate)
+			if err != nil {
+				log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
+				return err
+			}
		}
	}
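In a column-based import, each numpy file feeds exactly one field, and the file name (without extension) selects that field. A minimal usage sketch, not part of this diff: it assumes a *ImportWrapper built over a schema that contains a 4-dimensional float vector field named "field_float_vector", and reuses the CreateNumpyFile test helper added below in numpy_adapter.go.

func exampleNumpyImport(wrapper *ImportWrapper) error {
	// one .npy file per field; the file name (minus extension) selects the target field
	rows := [][4]float32{{1, 2, 3, 4}, {5, 6, 7, 8}}
	if err := CreateNumpyFile("/tmp/field_float_vector.npy", rows); err != nil {
		return err
	}
	// rowBased=false -> column-based import; onlyValidate=false -> parse and flush the data
	return wrapper.Import([]string{"/tmp/field_float_vector.npy"}, false, false)
}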
internal/util/importutil/import_wrapper_test.go

@@ -126,7 +126,7 @@ func Test_ImportRowBased(t *testing.T) {
}

-func Test_ImportColumnBased(t *testing.T) {
+func Test_ImportColumnBased_json(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)

@@ -208,6 +208,88 @@ func Test_ImportColumnBased(t *testing.T) {
	assert.NotNil(t, err)
}

func Test_ImportColumnBased_numpy(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	idAllocator := newIDAllocator(ctx, t)

	content := []byte(`{
		"field_bool": [true, false, true, true, true],
		"field_int8": [10, 11, 12, 13, 14],
		"field_int16": [100, 101, 102, 103, 104],
		"field_int32": [1000, 1001, 1002, 1003, 1004],
		"field_int64": [10000, 10001, 10002, 10003, 10004],
		"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
		"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
		"field_string": ["a", "b", "c", "d", "e"]
	}`)

	files := make([]string, 0)

	filePath := TempFilesPath + "scalar_fields.json"
	fp1 := saveFile(t, filePath, content)
	fp1.Close()
	files = append(files, filePath)

	filePath = TempFilesPath + "field_binary_vector.npy"
	bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}
	err = CreateNumpyFile(filePath, bin)
	assert.Nil(t, err)
	files = append(files, filePath)

	filePath = TempFilesPath + "field_float_vector.npy"
	flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}
	err = CreateNumpyFile(filePath, flo)
	assert.Nil(t, err)
	files = append(files, filePath)

	rowCount := 0
	flushFunc := func(fields map[string]storage.FieldData) error {
		count := 0
		for _, data := range fields {
			assert.Less(t, 0, data.RowNum())
			if count == 0 {
				count = data.RowNum()
			} else {
				assert.Equal(t, count, data.RowNum())
			}
		}
		rowCount += count
		return nil
	}

	// success case
	wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)

	err = wrapper.Import(files, false, false)
	assert.Nil(t, err)
	assert.Equal(t, 5, rowCount)

	// parse error
	content = []byte(`{
		"field_bool": [true, false, true, true, true]
	}`)

	filePath = TempFilesPath + "rows_2.json"
	fp2 := saveFile(t, filePath, content)
	defer fp2.Close()

	wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
	files = make([]string, 0)
	files = append(files, filePath)
	err = wrapper.Import(files, false, false)
	assert.NotNil(t, err)

	// file doesn't exist
	files = make([]string, 0)
	files = append(files, "/dummy/dummy.json")
	err = wrapper.Import(files, false, false)
	assert.NotNil(t, err)
}

func perfSchema(dim int) *schemapb.CollectionSchema {
	schema := &schemapb.CollectionSchema{
		Name: "schema",
internal/util/importutil/json_parser.go

@@ -24,7 +24,7 @@ type JSONParser struct {
	fields map[string]int64 // fields need to be parsed
}

-// newImportManager helper function to create a importManager
+// NewJSONParser helper function to create a JSONParser
func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
	fields := make(map[string]int64)
	for i := 0; i < len(collectionSchema.Fields); i++ {
internal/util/importutil/numpy_adapter.go | 356 (new file)

@@ -0,0 +1,356 @@
package importutil

import (
	"encoding/binary"
	"errors"
	"io"
	"os"

	"github.com/sbinet/npyio"
	"github.com/sbinet/npyio/npy"
)

func CreateNumpyFile(path string, data interface{}) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	err = npyio.Write(f, data)
	if err != nil {
		return err
	}

	return nil
}

// NumpyAdapter is a class that expands the abilities of the underlying numpy library.
// We evaluated two Go numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio.
// The npyio lib reads data one element at a time, so its performance is poor; we expand the read
// methods to read data in one batch, which is about 100X faster.
// The gonpy lib also reads data in one batch, but it has no method to read bool data, and its
// ability to handle different data types is not as strong as npyio's, so we chose npyio to expand.
type NumpyAdapter struct {
	reader       io.Reader        // data source, typically an os.File
	npyReader    *npy.Reader      // reader of the npyio lib
	order        binary.ByteOrder // LittleEndian or BigEndian
	readPosition int              // how many elements have been read
}

func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
	r, err := npyio.NewReader(reader)
	if err != nil {
		return nil, err
	}
	adapter := &NumpyAdapter{
		reader:       reader,
		npyReader:    r,
		readPosition: 0,
	}
	adapter.setByteOrder()

	return adapter, err
}

// the logic of this method is copied from the npyio lib
func (n *NumpyAdapter) setByteOrder() {
	var nativeEndian binary.ByteOrder
	v := uint16(1)
	switch byte(v >> 8) {
	case 0:
		nativeEndian = binary.LittleEndian
	case 1:
		nativeEndian = binary.BigEndian
	}

	switch n.npyReader.Header.Descr.Type[0] {
	case '<':
		n.order = binary.LittleEndian
	case '>':
		n.order = binary.BigEndian
	default:
		n.order = nativeEndian
	}
}
func (n *NumpyAdapter) Reader() io.Reader {
	return n.reader
}

func (n *NumpyAdapter) NpyReader() *npy.Reader {
	return n.npyReader
}

func (n *NumpyAdapter) GetType() string {
	return n.npyReader.Header.Descr.Type
}

func (n *NumpyAdapter) GetShape() []int {
	return n.npyReader.Header.Descr.Shape
}

func (n *NumpyAdapter) checkSize(size int) int {
	shape := n.GetShape()

	// empty file?
	if len(shape) == 0 {
		return 0
	}

	total := 1
	for i := 0; i < len(shape); i++ {
		total *= shape[i]
	}

	if total == 0 {
		return 0
	}

	// overflow?
	if size > (total - n.readPosition) {
		return total - n.readPosition
	}

	return size
}
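The truncation rule is easiest to see with concrete numbers (hypothetical values, not from the diff):

// worked example: a numpy file with shape [5, 4] holds 20 elements;
// with readPosition == 18, checkSize(8) returns 20 - 18 = 2, so the
// caller reads only the 2 remaining elements instead of overflowing;
// a further call yields readSize == 0 and a "nothing to read" error.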
func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "b1", "<b1", "|b1", "bool":
	default:
		return nil, errors.New("numpy data is not bool type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]bool, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "u1", "<u1", "|u1", "uint8":
	default:
		return nil, errors.New("numpy data is not uint8 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]uint8, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i1", "<i1", "|i1", ">i1", "int8":
	default:
		return nil, errors.New("numpy data is not int8 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int8, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i2", "<i2", "|i2", ">i2", "int16":
	default:
		return nil, errors.New("numpy data is not int16 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int16, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i4", "<i4", "|i4", ">i4", "int32":
	default:
		return nil, errors.New("numpy data is not int32 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int32, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i8", "<i8", "|i8", ">i8", "int64":
	default:
		return nil, errors.New("numpy data is not int64 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int64, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "f4", "<f4", "|f4", ">f4", "float32":
	default:
		return nil, errors.New("numpy data is not float32 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]float32, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "f8", "<f8", "|f8", ">f8", "float64":
	default:
		return nil, errors.New("numpy data is not float64 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]float64, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update the read position after a successful read
	n.readPosition += readSize

	return data, nil
}
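All the batch-read methods share the same pattern: checkSize truncates the request to the remaining element count, and a further read returns a "nothing to read" error. A minimal sketch of draining a float32 file in fixed-size batches — a hypothetical helper, not part of the commit; the adapter is assumed to come from NewNumpyAdapter over an os.File:

func readAllFloat32(adapter *NumpyAdapter, batchSize int) []float32 {
	result := make([]float32, 0)
	for {
		// the last batch is truncated by checkSize; once exhausted, an error ends the loop
		chunk, err := adapter.ReadFloat32(batchSize)
		if err != nil {
			break
		}
		result = append(result, chunk...)
	}
	return result
}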
internal/util/importutil/numpy_adapter_test.go | 455 (new file)

@@ -0,0 +1,455 @@
package importutil

import (
	"encoding/binary"
	"io"
	"os"
	"testing"

	"github.com/sbinet/npyio/npy"
	"github.com/stretchr/testify/assert"
)

type MockReader struct{}

func (r *MockReader) Read(p []byte) (n int, err error) {
	return 0, io.EOF
}

func Test_CreateNumpyFile(t *testing.T) {
	// directory doesn't exist
	data1 := []float32{1, 2, 3, 4, 5}
	err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
	assert.NotNil(t, err)

	// invalid data type
	data2 := make(map[string]int)
	err = CreateNumpyFile("/tmp/dummy.npy", data2)
	assert.NotNil(t, err)
}

func Test_SetByteOrder(t *testing.T) {
	adapter := &NumpyAdapter{
		reader:    nil,
		npyReader: &npy.Reader{},
	}
	assert.Nil(t, adapter.Reader())
	assert.NotNil(t, adapter.NpyReader())

	adapter.npyReader.Header.Descr.Type = "<i8"
	adapter.setByteOrder()
	assert.Equal(t, binary.LittleEndian, adapter.order)

	adapter.npyReader.Header.Descr.Type = ">i8"
	adapter.setByteOrder()
	assert.Equal(t, binary.BigEndian, adapter.order)
}

func Test_ReadError(t *testing.T) {
	adapter := &NumpyAdapter{
		reader:    nil,
		npyReader: nil,
	}

	// reader is nil
	{
		_, err := adapter.ReadBool(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadUint8(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt8(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt16(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt32(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt64(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadFloat32(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadFloat64(1)
		assert.NotNil(t, err)
	}

	adapter = &NumpyAdapter{
		reader:    &MockReader{},
		npyReader: &npy.Reader{},
	}

	{
		adapter.npyReader.Header.Descr.Type = "bool"
		data, err := adapter.ReadBool(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadBool(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "u1"
		data, err := adapter.ReadUint8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadUint8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i1"
		data, err := adapter.ReadInt8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i2"
		data, err := adapter.ReadInt16(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt16(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i4"
		data, err := adapter.ReadInt32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i8"
		data, err := adapter.ReadInt64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "f4"
		data, err := adapter.ReadFloat32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadFloat32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "f8"
		data, err := adapter.ReadFloat64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadFloat64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}
}
func Test_Read(t *testing.T) {
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	{
		filePath := TempFilesPath + "bool.npy"
		data := []bool{true, false, true, false}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadBool(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadBool(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadBool(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)

		// incorrect type read
		resu1, err := adapter.ReadUint8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resu1)

		resi1, err := adapter.ReadInt8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi1)

		resi2, err := adapter.ReadInt16(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi2)

		resi4, err := adapter.ReadInt32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi4)

		resi8, err := adapter.ReadInt64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi8)

		resf4, err := adapter.ReadFloat32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resf4)

		resf8, err := adapter.ReadFloat64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resf8)
	}

	{
		filePath := TempFilesPath + "uint8.npy"
		data := []uint8{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadUint8(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadUint8(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadUint8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)

		// incorrect type read
		resb, err := adapter.ReadBool(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resb)
	}

	{
		filePath := TempFilesPath + "int8.npy"
		data := []int8{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt8(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt8(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int16.npy"
		data := []int16{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt16(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt16(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt16(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int32.npy"
		data := []int32{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt32(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt32(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int64.npy"
		data := []int64{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt64(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt64(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "float.npy"
		data := []float32{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadFloat32(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadFloat32(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadFloat32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "double.npy"
		data := []float64{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadFloat64(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadFloat64(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadFloat64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}
}
internal/util/importutil/numpy_parser.go | 290 (new file)

@@ -0,0 +1,290 @@
package importutil

import (
	"context"
	"errors"
	"io"
	"strconv"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
)

type ColumnDesc struct {
	name         string            // name of the target column
	dt           schemapb.DataType // data type of the target column
	elementCount int               // how many elements need to be read
	dimension    int               // only for vector
}

type NumpyParser struct {
	ctx              context.Context            // for canceling the parse process
	collectionSchema *schemapb.CollectionSchema // collection schema
	columnDesc       *ColumnDesc                // description of the target column

	columnData    storage.FieldData                   // in-memory column data
	callFlushFunc func(field storage.FieldData) error // callback function to output column data
}

// NewNumpyParser helper function to create a NumpyParser
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
	flushFunc func(field storage.FieldData) error) *NumpyParser {
	if collectionSchema == nil || flushFunc == nil {
		return nil
	}

	parser := &NumpyParser{
		ctx:              ctx,
		collectionSchema: collectionSchema,
		columnDesc:       &ColumnDesc{},
		callFlushFunc:    flushFunc,
	}

	return parser
}
func (p *NumpyParser) logError(msg string) error {
	log.Error(msg)
	return errors.New(msg)
}

// data type converted from the numpy header description; for a vector field, the element type
// is uint8 (binary vector) or float32 (float vector)
func convertNumpyType(str string) (schemapb.DataType, error) {
	switch str {
	case "b1", "<b1", "|b1", "bool":
		return schemapb.DataType_Bool, nil
	case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
		return schemapb.DataType_BinaryVector, nil
	case "i1", "<i1", "|i1", ">i1", "int8":
		return schemapb.DataType_Int8, nil
	case "i2", "<i2", "|i2", ">i2", "int16":
		return schemapb.DataType_Int16, nil
	case "i4", "<i4", "|i4", ">i4", "int32":
		return schemapb.DataType_Int32, nil
	case "i8", "<i8", "|i8", ">i8", "int64":
		return schemapb.DataType_Int64, nil
	case "f4", "<f4", "|f4", ">f4", "float32":
		return schemapb.DataType_Float, nil
	case "f8", "<f8", "|f8", ">f8", "float64":
		return schemapb.DataType_Double, nil
	default:
		return schemapb.DataType_None, errors.New("unsupported data type " + str)
	}
}

func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
	if adapter == nil {
		return errors.New("numpy adapter is nil")
	}

	// check existence of the target field
	var schema *schemapb.FieldSchema
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema = p.collectionSchema.Fields[i]
		if schema.GetName() == fieldName {
			p.columnDesc.name = fieldName
			break
		}
	}

	if p.columnDesc.name == "" {
		return errors.New("the field " + fieldName + " doesn't exist")
	}

	p.columnDesc.dt = schema.DataType
	elementType, err := convertNumpyType(adapter.GetType())
	if err != nil {
		return err
	}

	shape := adapter.GetShape()

	// 1. the field data type should be consistent with the numpy data type
	// 2. a vector field's dimension should be consistent with the numpy shape
	if schemapb.DataType_FloatVector == schema.DataType {
		if elementType != schemapb.DataType_Float {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// vector field, the shape should be 2
		if len(shape) != 2 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		// shape[0] is the row count, shape[1] is the element count per row
		p.columnDesc.elementCount = shape[0] * shape[1]

		p.columnDesc.dimension, err = getFieldDimension(schema)
		if err != nil {
			return err
		}

		if shape[1] != p.columnDesc.dimension {
			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
		}
	} else if schemapb.DataType_BinaryVector == schema.DataType {
		if elementType != schemapb.DataType_BinaryVector {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// vector field, the shape should be 2
		if len(shape) != 2 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		// shape[0] is the row count, shape[1] is the element count per row
		p.columnDesc.elementCount = shape[0] * shape[1]

		p.columnDesc.dimension, err = getFieldDimension(schema)
		if err != nil {
			return err
		}

		// a binary vector packs 8 dimensions per byte
		if shape[1] != p.columnDesc.dimension/8 {
			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
		}
	} else {
		if elementType != schema.DataType {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// scalar field, the shape should be 1
		if len(shape) != 1 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		p.columnDesc.elementCount = shape[0]
	}

	return nil
}
// this method reads the numpy data section into a storage.FieldData
// note that it requires a large memory block (almost equal to the numpy file size)
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
	switch p.columnDesc.dt {
	case schemapb.DataType_Bool:
		data, err := adapter.ReadBool(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.BoolFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}

	case schemapb.DataType_Int8:
		data, err := adapter.ReadInt8(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int8FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Int16:
		data, err := adapter.ReadInt16(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int16FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Int32:
		data, err := adapter.ReadInt32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int32FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Int64:
		data, err := adapter.ReadInt64(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int64FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Float:
		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.FloatFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Double:
		data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.DoubleFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_BinaryVector:
		data, err := adapter.ReadUint8(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.BinaryVectorFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
			Dim:     p.columnDesc.dimension,
		}
	case schemapb.DataType_FloatVector:
		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.FloatVectorFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
			Dim:     p.columnDesc.dimension,
		}
	default:
		return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
	}

	return nil
}

func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
	adapter, err := NewNumpyAdapter(reader)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	// the validation method only checks the file header information
	err = p.validate(adapter, fieldName)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	if onlyValidate {
		return nil
	}

	// read all data from the numpy file
	err = p.consume(adapter)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	return p.callFlushFunc(p.columnData)
}
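Driving the parser directly follows the same contract as the import wrapper: validate the numpy header against the schema, then read the whole column and hand it to the flush callback in a single call. A minimal sketch, assuming a schema that contains a scalar field named "field_int32"; the file path is hypothetical:

func exampleParseNumpy(ctx context.Context, schema *schemapb.CollectionSchema) error {
	// the file name (without extension) must match the target field name
	file, err := os.Open("/tmp/field_int32.npy")
	if err != nil {
		return err
	}
	defer file.Close()

	flushFunc := func(field storage.FieldData) error {
		// receives the entire column in one call
		return nil
	}
	parser := NewNumpyParser(ctx, schema, flushFunc)
	// onlyValidate=false: read the data section and invoke flushFunc
	return parser.Parse(file, "field_int32", false)
}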
internal/util/importutil/numpy_parser_test.go | 509 (new file)

@@ -0,0 +1,509 @@
package importutil

import (
	"context"
	"os"
	"testing"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/sbinet/npyio/npy"
	"github.com/stretchr/testify/assert"
)

func Test_NewNumpyParser(t *testing.T) {
	ctx := context.Background()

	parser := NewNumpyParser(ctx, nil, nil)
	assert.Nil(t, parser)
}

func Test_ConvertNumpyType(t *testing.T) {
	checkFunc := func(inputs []string, output schemapb.DataType) {
		for i := 0; i < len(inputs); i++ {
			dt, err := convertNumpyType(inputs[i])
			assert.Nil(t, err)
			assert.Equal(t, output, dt)
		}
	}

	checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
	checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
	checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
	checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
	checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
	checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
	checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)

	dt, err := convertNumpyType("dummy")
	assert.NotNil(t, err)
	assert.Equal(t, schemapb.DataType_None, dt)
}
func Test_Validate(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	schema := sampleSchema()
	flushFunc := func(field storage.FieldData) error {
		return nil
	}

	adapter := &NumpyAdapter{npyReader: &npy.Reader{}}

	{
		// string type is not supported
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:      109,
					Name:         "field_string",
					IsPrimaryKey: false,
					Description:  "string",
					DataType:     schemapb.DataType_String,
				},
			},
		}, flushFunc)
		err = p.validate(adapter, "dummy")
		assert.NotNil(t, err)
		err = p.validate(adapter, "field_string")
		assert.NotNil(t, err)
	}

	// reader is nil
	parser := NewNumpyParser(ctx, schema, flushFunc)
	err = parser.validate(nil, "")
	assert.NotNil(t, err)

	// validate scalar data
	func() {
		filePath := TempFilesPath + "scalar_1.npy"
		data1 := []float64{0, 1, 2, 3, 4, 5}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.Nil(t, err)
		assert.Equal(t, len(data1), parser.columnDesc.elementCount)

		err = parser.validate(adapter, "")
		assert.NotNil(t, err)

		// data type mismatch
		filePath = TempFilesPath + "scalar_2.npy"
		data2 := []int64{0, 1, 2, 3, 4, 5}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "scalar_3.npy"
		data3 := [][2]float64{{1, 1}}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)
	}()

	// validate binary vector data
	func() {
		filePath := TempFilesPath + "binary_vector_1.npy"
		data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.Nil(t, err)
		assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)

		// data type mismatch
		filePath = TempFilesPath + "binary_vector_2.npy"
		data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "binary_vector_3.npy"
		data3 := []uint8{1, 2, 3}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// shape[1] mismatch
		filePath = TempFilesPath + "binary_vector_4.npy"
		data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
		CreateNumpyFile(filePath, data4)

		file4, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file4.Close()

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// dimension mismatch
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:  109,
					Name:     "field_binary_vector",
					DataType: schemapb.DataType_BinaryVector,
				},
			},
		}, flushFunc)

		err = p.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)
	}()

	// validate float vector data
	func() {
		filePath := TempFilesPath + "float_vector.npy"
		data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.Nil(t, err)
		assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)

		// data type mismatch
		filePath = TempFilesPath + "float_vector_2.npy"
		data2 := [][4]int32{{0, 1, 2, 3}}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "float_vector_3.npy"
		data3 := []float32{1, 2, 3}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// shape[1] mismatch
		filePath = TempFilesPath + "float_vector_4.npy"
		data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}}
		CreateNumpyFile(filePath, data4)

		file4, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file4.Close()

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// dimension mismatch
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:  109,
					Name:     "field_float_vector",
					DataType: schemapb.DataType_FloatVector,
				},
			},
		}, flushFunc)

		err = p.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)
	}()
}
func Test_Parse(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	schema := sampleSchema()

	checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) {
		filePath := TempFilesPath + fieldName + ".npy"
		CreateNumpyFile(filePath, data)

		func() {
			file, err := os.Open(filePath)
			assert.Nil(t, err)
			defer file.Close()

			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(file, fieldName, false)
			assert.Nil(t, err)
		}()

		// validation failed
		func() {
			file, err := os.Open(filePath)
			assert.Nil(t, err)
			defer file.Close()

			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(file, "dummy", false)
			assert.NotNil(t, err)
		}()

		// read data error
		func() {
			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(&MockReader{}, fieldName, false)
			assert.NotNil(t, err)
		}()
	}

	// scalar bool
	data1 := []bool{true, false, true, false, true}
	flushFunc := func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data1), field.RowNum())

		for i := 0; i < len(data1); i++ {
			assert.Equal(t, data1[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data1, "field_bool", flushFunc)

	// scalar int8
	data2 := []int8{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data2), field.RowNum())

		for i := 0; i < len(data2); i++ {
			assert.Equal(t, data2[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data2, "field_int8", flushFunc)

	// scalar int16
	data3 := []int16{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data3), field.RowNum())

		for i := 0; i < len(data3); i++ {
			assert.Equal(t, data3[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data3, "field_int16", flushFunc)

	// scalar int32
	data4 := []int32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data4), field.RowNum())

		for i := 0; i < len(data4); i++ {
			assert.Equal(t, data4[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data4, "field_int32", flushFunc)

	// scalar int64
	data5 := []int64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data5), field.RowNum())

		for i := 0; i < len(data5); i++ {
			assert.Equal(t, data5[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data5, "field_int64", flushFunc)

	// scalar float
	data6 := []float32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data6), field.RowNum())

		for i := 0; i < len(data6); i++ {
			assert.Equal(t, data6[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data6, "field_float", flushFunc)

	// scalar double
	data7 := []float64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data7), field.RowNum())

		for i := 0; i < len(data7); i++ {
			assert.Equal(t, data7[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data7, "field_double", flushFunc)

	// binary vector
	data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data8), field.RowNum())

		for i := 0; i < len(data8); i++ {
			row := field.GetRow(i).([]uint8)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data8[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data8, "field_binary_vector", flushFunc)

	// float vector
	data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data9), field.RowNum())

		for i := 0; i < len(data9); i++ {
			row := field.GetRow(i).([]float32)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data9[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data9, "field_float_vector", flushFunc)
}
func Test_Parse_perf(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	tr := timerecord.NewTimeRecorder("numpy parse performance")

	// change the parameters to test performance
	rowCount := 10000
	dotValue := float32(3.1415926)
	const (
		dim = 128
	)

	schema := perfSchema(dim)

	data := make([][dim]float32, 0)
	for i := 0; i < rowCount; i++ {
		var row [dim]float32
		for k := 0; k < dim; k++ {
			row[k] = float32(i) + dotValue
		}
		data = append(data, row)
	}

	tr.Record("generate large data")

	flushFunc := func(field storage.FieldData) error {
		assert.Equal(t, len(data), field.RowNum())
		return nil
	}

	filePath := TempFilesPath + "perf.npy"
	CreateNumpyFile(filePath, data)

	tr.Record("generate large numpy file " + filePath)

	file, err := os.Open(filePath)
	assert.Nil(t, err)
	defer file.Close()

	parser := NewNumpyParser(ctx, schema, flushFunc)
	err = parser.Parse(file, "Vector", false)
	assert.Nil(t, err)

	tr.Record("parse large numpy files: " + filePath)
}
@@ -47,6 +47,7 @@ go test -race -cover "${MILVUS_DIR}/util/retry/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/sessionutil/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/trace/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/typeutil/..." -failfast
+go test -race -cover "${MILVUS_DIR}/util/importutil/..." -failfast

# TODO: remove to distributed
#go test -race -cover "${MILVUS_DIR}/proxy/..." -failfast