Support import numpy file (#16348)

Signed-off-by: groot <yihua.mo@zilliz.com>
This commit is contained in:
groot 2022-04-03 11:27:29 +08:00 committed by GitHub
parent 6682d1b635
commit be8d9a8b6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1736 additions and 7 deletions

5
go.mod
View File

@ -32,6 +32,8 @@ require (
github.com/pierrec/lz4 v2.5.2+incompatible // indirect
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.11.0
github.com/quasilyte/go-ruleguard v0.2.1 // indirect
github.com/sbinet/npyio v0.6.0
github.com/shirou/gopsutil v3.21.8+incompatible
github.com/spaolacci/murmur3 v1.1.0
github.com/spf13/cast v1.3.1
@ -59,4 +61,5 @@ replace (
github.com/dgrijalva/jwt-go => github.com/golang-jwt/jwt v3.2.2+incompatible // Fix security alert for jwt-go 3.2.0
github.com/keybase/go-keychain => github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4
google.golang.org/grpc => google.golang.org/grpc v1.38.0
)

6
go.sum
View File

@ -110,6 +110,7 @@ github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJm
github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
@ -584,6 +585,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@ -598,6 +601,8 @@ github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfF
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4=
github.com/sbinet/npyio v0.6.0 h1:IyqqQIzRjDym9xnIXsToCKei/qCzxDP+Y74KoMlMgXo=
github.com/sbinet/npyio v0.6.0/go.mod h1:/q3BNr6dJOy+t6h7RZchTJ0nwRJO52mivaem29WE1j8=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shirou/gopsutil v3.21.8+incompatible h1:sh0foI8tMRlCidUJR+KzqWYWxrkuuPIGiO6Vp+KXdCU=
github.com/shirou/gopsutil v3.21.8+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
@ -1038,6 +1043,7 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=

View File

@ -7,6 +7,7 @@ import (
"os"
"path"
"strconv"
"strings"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
@ -89,6 +90,13 @@ func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldD
log.Debug(msg, stats...)
}
func getFileNameAndExt(filePath string) (string, string) {
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
return fileNameWithoutExt, fileType
}
// import process entry
// filePath and rowBased are from ImportTask
// if onlyValidate is true, this process only does validation; no data is generated and callFlushFunc will not be called
@ -99,8 +107,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
_, fileType := getFileNameAndExt(filePath)
log.Debug("imprort wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
@ -183,8 +190,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
// parse/validate/consume data
for i := 0; i < len(filePaths); i++ {
filePath := filePaths[i]
fileName := path.Base(filePath)
fileType := path.Ext(fileName)
fileName, fileType := getFileNameAndExt(filePath)
log.Debug("imprort wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
if fileType == JSONFileExt {
@ -218,7 +224,28 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
return err
}
} else if fileType == NumpyFileExt {
file, err := os.Open(filePath)
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
defer file.Close()
// the numpy parser returns a storage.FieldData; construct a map[string]storage.FieldData here to combine
flushFunc := func(field storage.FieldData) error {
fields := make(map[string]storage.FieldData)
fields[fileName] = field
combineFunc(fields)
return nil
}
// for numpy files, the file name (without extension) is treated as the field name
parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
err = parser.Parse(file, fileName, onlyValidate)
if err != nil {
log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath))
return err
}
}
}
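As a usage reference, here is a minimal sketch of driving this column-based numpy path end to end, mirroring the unit tests below; ctx, collectionSchema and idAllocator are assumed to be prepared by the caller, and the shard number (2) and segment size (1) are placeholder values:

// one .npy file per field; the file name (without extension) must match the target field name
_ = CreateNumpyFile("/tmp/field_float_vector.npy", [][4]float32{{1, 2, 3, 4}})
flushFunc := func(fields map[string]storage.FieldData) error {
// the combined column data arrives here, e.g. to be handed over to a segment writer
return nil
}
wrapper := NewImportWrapper(ctx, collectionSchema, 2, 1, idAllocator, flushFunc)
// rowBased=false selects the column-based branch; onlyValidate=true would stop after validation
err := wrapper.Import([]string{"/tmp/field_float_vector.npy"}, false, false)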

View File

@ -126,7 +126,7 @@ func Test_ImportRowBased(t *testing.T) {
}
func Test_ImportColumnBased(t *testing.T) {
func Test_ImportColumnBased_json(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
@ -208,6 +208,88 @@ func Test_ImportColumnBased(t *testing.T) {
assert.NotNil(t, err)
}
func Test_ImportColumnBased_numpy(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
idAllocator := newIDAllocator(ctx, t)
content := []byte(`{
"field_bool": [true, false, true, true, true],
"field_int8": [10, 11, 12, 13, 14],
"field_int16": [100, 101, 102, 103, 104],
"field_int32": [1000, 1001, 1002, 1003, 1004],
"field_int64": [10000, 10001, 10002, 10003, 10004],
"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
"field_string": ["a", "b", "c", "d", "e"]
}`)
files := make([]string, 0)
filePath := TempFilesPath + "scalar_fields.json"
fp1 := saveFile(t, filePath, content)
fp1.Close()
files = append(files, filePath)
filePath = TempFilesPath + "field_binary_vector.npy"
bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}
err = CreateNumpyFile(filePath, bin)
assert.Nil(t, err)
files = append(files, filePath)
filePath = TempFilesPath + "field_float_vector.npy"
flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}
err = CreateNumpyFile(filePath, flo)
assert.Nil(t, err)
files = append(files, filePath)
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
if count == 0 {
count = data.RowNum()
} else {
assert.Equal(t, count, data.RowNum())
}
}
rowCount += count
return nil
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
err = wrapper.Import(files, false, false)
assert.Nil(t, err)
assert.Equal(t, 5, rowCount)
// parse error
content = []byte(`{
"field_bool": [true, false, true, true, true]
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
assert.NotNil(t, err)
// file doesn't exist
files = make([]string, 0)
files = append(files, "/dummy/dummy.json")
err = wrapper.Import(files, false, false)
assert.NotNil(t, err)
}
func perfSchema(dim int) *schemapb.CollectionSchema {
schema := &schemapb.CollectionSchema{
Name: "schema",

View File

@ -24,7 +24,7 @@ type JSONParser struct {
fields map[string]int64 // fields need to be parsed
}
// newImportManager helper function to create a importManager
// NewJSONParser helper function to create a JSONParser
func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
fields := make(map[string]int64)
for i := 0; i < len(collectionSchema.Fields); i++ {

View File

@ -0,0 +1,356 @@
package importutil
import (
"encoding/binary"
"errors"
"io"
"os"
"github.com/sbinet/npyio"
"github.com/sbinet/npyio/npy"
)
func CreateNumpyFile(path string, data interface{}) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
err = npyio.Write(f, data)
if err != nil {
return err
}
return nil
}
// NumpyAdapter is a class that expands the abilities of the numpy libraries.
// We evaluated two Go numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio.
// The npyio lib reads data element by element, so its performance is poor; we expand the read methods
// to read data in one batch, which is about 100X faster.
// The gonpy lib also reads data in one batch, but it has no method to read bool data, and its ability
// to handle different data types is not as strong as npyio's, so we chose the npyio lib and expanded it.
type NumpyAdapter struct {
reader io.Reader // data source, typically is os.File
npyReader *npy.Reader // reader of npyio lib
order binary.ByteOrder // LittleEndian or BigEndian
readPosition int // how many elements have been read
}
func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
r, err := npyio.NewReader(reader)
if err != nil {
return nil, err
}
adapter := &NumpyAdapter{
reader: reader,
npyReader: r,
readPosition: 0,
}
adapter.setByteOrder()
return adapter, err
}
// determine the native endianness by inspecting the in-memory byte layout of a uint16
// (the intent follows the npyio lib)
func (n *NumpyAdapter) setByteOrder() {
var nativeEndian binary.ByteOrder
v := uint16(1)
switch *(*byte)(unsafe.Pointer(&v)) {
case 1:
nativeEndian = binary.LittleEndian
case 0:
nativeEndian = binary.BigEndian
}
switch n.npyReader.Header.Descr.Type[0] {
case '<':
n.order = binary.LittleEndian
case '>':
n.order = binary.BigEndian
default:
n.order = nativeEndian
}
}
func (n *NumpyAdapter) Reader() io.Reader {
return n.reader
}
func (n *NumpyAdapter) NpyReader() *npy.Reader {
return n.npyReader
}
func (n *NumpyAdapter) GetType() string {
return n.npyReader.Header.Descr.Type
}
func (n *NumpyAdapter) GetShape() []int {
return n.npyReader.Header.Descr.Shape
}
func (n *NumpyAdapter) checkSize(size int) int {
shape := n.GetShape()
// empty file?
if len(shape) == 0 {
return 0
}
total := 1
for i := 0; i < len(shape); i++ {
total *= shape[i]
}
if total == 0 {
return 0
}
// if the requested size exceeds the remaining elements, clamp to what is left
if size > (total - n.readPosition) {
return total - n.readPosition
}
return size
}
func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "b1", "<b1", "|b1", "bool":
default:
return nil, errors.New("numpy data is not bool type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]bool, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "u1", "<u1", "|u1", "uint8":
default:
return nil, errors.New("numpy data is not uint8 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]uint8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i1", "<i1", "|i1", ">i1", "int8":
default:
return nil, errors.New("numpy data is not int8 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]int8, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i2", "<i2", "|i2", ">i2", "int16":
default:
return nil, errors.New("numpy data is not int16 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]int16, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i4", "<i4", "|i4", ">i4", "int32":
default:
return nil, errors.New("numpy data is not int32 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]int32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "i8", "<i8", "|i8", ">i8", "int64":
default:
return nil, errors.New("numpy data is not int64 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]int64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "f4", "<f4", "|f4", ">f4", "float32":
default:
return nil, errors.New("numpy data is not float32 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]float32, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
if n.npyReader == nil {
return nil, errors.New("reader is not initialized")
}
// incorrect type
switch n.npyReader.Header.Descr.Type {
case "f8", "<f8", "|f8", ">f8", "float64":
default:
return nil, errors.New("numpy data is not float32 type")
}
// avoid read overflow
readSize := n.checkSize(size)
if readSize <= 0 {
return nil, errors.New("nothing to read")
}
data := make([]float64, readSize)
err := binary.Read(n.reader, n.order, &data)
if err != nil {
return nil, err
}
// update the read position after a successful read
n.readPosition += readSize
return data, nil
}
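For illustration, a minimal sketch of batch-reading a numpy file through the adapter; the file path and the batch size of 1024 are placeholders:

file, err := os.Open("/tmp/field_int32.npy")
if err != nil {
panic(err)
}
defer file.Close()
adapter, err := NewNumpyAdapter(file)
if err != nil {
panic(err)
}
// each call reads up to 1024 elements; once the file is exhausted,
// checkSize reports nothing left and the read returns an error
for {
batch, err := adapter.ReadInt32(1024)
if err != nil {
break
}
_ = batch // process the batch
}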

View File

@ -0,0 +1,455 @@
package importutil
import (
"encoding/binary"
"io"
"os"
"testing"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
)
type MockReader struct {
}
func (r *MockReader) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
func Test_CreateNumpyFile(t *testing.T) {
// directory doesn't exist
data1 := []float32{1, 2, 3, 4, 5}
err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
assert.NotNil(t, err)
// invalid data type
data2 := make(map[string]int)
err = CreateNumpyFile("/tmp/dummy.npy", data2)
assert.NotNil(t, err)
}
func Test_SetByteOrder(t *testing.T) {
adapter := &NumpyAdapter{
reader: nil,
npyReader: &npy.Reader{},
}
assert.Nil(t, adapter.Reader())
assert.NotNil(t, adapter.NpyReader())
adapter.npyReader.Header.Descr.Type = "<i8"
adapter.setByteOrder()
assert.Equal(t, binary.LittleEndian, adapter.order)
adapter.npyReader.Header.Descr.Type = ">i8"
adapter.setByteOrder()
assert.Equal(t, binary.BigEndian, adapter.order)
}
func Test_ReadError(t *testing.T) {
adapter := &NumpyAdapter{
reader: nil,
npyReader: nil,
}
// npyReader is nil, all read methods should return an error
{
_, err := adapter.ReadBool(1)
assert.NotNil(t, err)
_, err = adapter.ReadUint8(1)
assert.NotNil(t, err)
_, err = adapter.ReadInt8(1)
assert.NotNil(t, err)
_, err = adapter.ReadInt16(1)
assert.NotNil(t, err)
_, err = adapter.ReadInt32(1)
assert.NotNil(t, err)
_, err = adapter.ReadInt64(1)
assert.NotNil(t, err)
_, err = adapter.ReadFloat32(1)
assert.NotNil(t, err)
_, err = adapter.ReadFloat64(1)
assert.NotNil(t, err)
}
adapter = &NumpyAdapter{
reader: &MockReader{},
npyReader: &npy.Reader{},
}
{
adapter.npyReader.Header.Descr.Type = "bool"
data, err := adapter.ReadBool(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadBool(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "u1"
data, err := adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadUint8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "i1"
data, err := adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadInt8(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "i2"
data, err := adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadInt16(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "i4"
data, err := adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadInt32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "i8"
data, err := adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadInt64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "f4"
data, err := adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadFloat32(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
{
adapter.npyReader.Header.Descr.Type = "f8"
data, err := adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
adapter.npyReader.Header.Descr.Type = "dummy"
data, err = adapter.ReadFloat64(1)
assert.Nil(t, data)
assert.NotNil(t, err)
}
}
func Test_Read(t *testing.T) {
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
{
filePath := TempFilesPath + "bool.npy"
data := []bool{true, false, true, false}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadBool(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadBool(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadBool(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
// incorrect type read
resu1, err := adapter.ReadUint8(len(data))
assert.NotNil(t, err)
assert.Nil(t, resu1)
resi1, err := adapter.ReadInt8(len(data))
assert.NotNil(t, err)
assert.Nil(t, resi1)
resi2, err := adapter.ReadInt16(len(data))
assert.NotNil(t, err)
assert.Nil(t, resi2)
resi4, err := adapter.ReadInt32(len(data))
assert.NotNil(t, err)
assert.Nil(t, resi4)
resi8, err := adapter.ReadInt64(len(data))
assert.NotNil(t, err)
assert.Nil(t, resi8)
resf4, err := adapter.ReadFloat32(len(data))
assert.NotNil(t, err)
assert.Nil(t, resf4)
resf8, err := adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.Nil(t, resf8)
}
{
filePath := TempFilesPath + "uint8.npy"
data := []uint8{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadUint8(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadUint8(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadUint8(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
// incorrect type read
resb, err := adapter.ReadBool(len(data))
assert.NotNil(t, err)
assert.Nil(t, resb)
}
{
filePath := TempFilesPath + "int8.npy"
data := []int8{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadInt8(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadInt8(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadInt8(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
{
filePath := TempFilesPath + "int16.npy"
data := []int16{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadInt16(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadInt16(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadInt16(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
{
filePath := TempFilesPath + "int32.npy"
data := []int32{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadInt32(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadInt32(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadInt32(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
{
filePath := TempFilesPath + "int64.npy"
data := []int64{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadInt64(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadInt64(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadInt64(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
{
filePath := TempFilesPath + "float.npy"
data := []float32{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadFloat32(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadFloat32(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadFloat32(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
{
filePath := TempFilesPath + "double.npy"
data := []float64{1, 2, 3, 4, 5, 6}
CreateNumpyFile(filePath, data)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
adapter, err := NewNumpyAdapter(file)
assert.Nil(t, err)
res, err := adapter.ReadFloat64(len(data) - 1)
assert.Nil(t, err)
assert.Equal(t, len(data)-1, len(res))
for i := 0; i < len(res); i++ {
assert.Equal(t, data[i], res[i])
}
res, err = adapter.ReadFloat64(len(data))
assert.Nil(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, data[len(data)-1], res[0])
res, err = adapter.ReadFloat64(len(data))
assert.NotNil(t, err)
assert.Nil(t, res)
}
}

View File

@ -0,0 +1,290 @@
package importutil
import (
"context"
"errors"
"io"
"strconv"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
)
type ColumnDesc struct {
name string // name of the target column
dt schemapb.DataType // data type of the target column
elementCount int // how many elements need to be read
dimension int // only for vector
}
type NumpyParser struct {
ctx context.Context // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
columnDesc *ColumnDesc // description for target column
columnData storage.FieldData // in-memory column data
callFlushFunc func(field storage.FieldData) error // call back function to output column data
}
// NewNumpyParser helper function to create a NumpyParser
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
flushFunc func(field storage.FieldData) error) *NumpyParser {
if collectionSchema == nil || flushFunc == nil {
return nil
}
parser := &NumpyParser{
ctx: ctx,
collectionSchema: collectionSchema,
columnDesc: &ColumnDesc{},
callFlushFunc: flushFunc,
}
return parser
}
func (p *NumpyParser) logError(msg string) error {
log.Error(msg)
return errors.New(msg)
}
// convertNumpyType converts the data type from a numpy header description; for vector fields, the element type is uint8 (binary vector) or float32 (float vector)
func convertNumpyType(str string) (schemapb.DataType, error) {
switch str {
case "b1", "<b1", "|b1", "bool":
return schemapb.DataType_Bool, nil
case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
return schemapb.DataType_BinaryVector, nil
case "i1", "<i1", "|i1", ">i1", "int8":
return schemapb.DataType_Int8, nil
case "i2", "<i2", "|i2", ">i2", "int16":
return schemapb.DataType_Int16, nil
case "i4", "<i4", "|i4", ">i4", "int32":
return schemapb.DataType_Int32, nil
case "i8", "<i8", "|i8", ">i8", "int64":
return schemapb.DataType_Int64, nil
case "f4", "<f4", "|f4", ">f4", "float32":
return schemapb.DataType_Float, nil
case "f8", "<f8", "|f8", ">f8", "float64":
return schemapb.DataType_Double, nil
default:
return schemapb.DataType_None, errors.New("unsupported data type " + str)
}
}
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
if adapter == nil {
return errors.New("numpy adapter is nil")
}
// check existence of the target field
var schema *schemapb.FieldSchema
for i := 0; i < len(p.collectionSchema.Fields); i++ {
schema = p.collectionSchema.Fields[i]
if schema.GetName() == fieldName {
p.columnDesc.name = fieldName
break
}
}
if p.columnDesc.name == "" {
return errors.New("the field " + fieldName + " doesn't exist")
}
p.columnDesc.dt = schema.DataType
elementType, err := convertNumpyType(adapter.GetType())
if err != nil {
return err
}
shape := adapter.GetShape()
// 1. the field data type must be consistent with the numpy data type
// 2. a vector field's dimension must be consistent with the numpy shape
if schemapb.DataType_FloatVector == schema.DataType {
if elementType != schemapb.DataType_Float {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
}
// for a vector field, the shape length should be 2
if len(shape) != 2 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
p.columnDesc.elementCount = shape[0] * shape[1]
p.columnDesc.dimension, err = getFieldDimension(schema)
if err != nil {
return err
}
if shape[1] != p.columnDesc.dimension {
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
}
} else if schemapb.DataType_BinaryVector == schema.DataType {
if elementType != schemapb.DataType_BinaryVector {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
}
// for a vector field, the shape length should be 2
if len(shape) != 2 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
}
// shape[0] is row count, shape[1] is element count per row
p.columnDesc.elementCount = shape[0] * shape[1]
p.columnDesc.dimension, err = getFieldDimension(schema)
if err != nil {
return err
}
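// a binary vector packs 8 dimensions into each byte, so the numpy row width must equal dim/8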
if shape[1] != p.columnDesc.dimension/8 {
return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
}
} else {
if elementType != schema.DataType {
return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
}
// for a scalar field, the shape length should be 1
if len(shape) != 1 {
return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
}
p.columnDesc.elementCount = shape[0]
}
return nil
}
// consume reads the numpy data section into a storage.FieldData
// note that it requires a large memory block (roughly equal to the numpy file size)
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
switch p.columnDesc.dt {
case schemapb.DataType_Bool:
data, err := adapter.ReadBool(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.BoolFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Int8:
data, err := adapter.ReadInt8(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.Int8FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Int16:
data, err := adapter.ReadInt16(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.Int16FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Int32:
data, err := adapter.ReadInt32(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.Int32FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Int64:
data, err := adapter.ReadInt64(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.Int64FieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Float:
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.FloatFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_Double:
data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.DoubleFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
}
case schemapb.DataType_BinaryVector:
data, err := adapter.ReadUint8(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.BinaryVectorFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
Dim: p.columnDesc.dimension,
}
case schemapb.DataType_FloatVector:
data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
if err != nil {
return err
}
p.columnData = &storage.FloatVectorFieldData{
NumRows: []int64{int64(p.columnDesc.elementCount)},
Data: data,
Dim: p.columnDesc.dimension,
}
default:
return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
}
return nil
}
func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
adapter, err := NewNumpyAdapter(reader)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
}
// the validation method only checks the file header information
err = p.validate(adapter, fieldName)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
}
if onlyValidate {
return nil
}
// read all data from the numpy file
err = p.consume(adapter)
if err != nil {
return p.logError("Numpy parse: " + err.Error())
}
return p.callFlushFunc(p.columnData)
}
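A minimal sketch of using the parser directly; collectionSchema is assumed to contain a scalar field named field_double, and the flush callback receives the whole column at once:

flushFunc := func(field storage.FieldData) error {
// field holds the entire column read from the numpy file
return nil
}
parser := NewNumpyParser(ctx, collectionSchema, flushFunc)
file, err := os.Open("/tmp/field_double.npy")
if err != nil {
panic(err)
}
defer file.Close()
// the field name must match a schema field; onlyValidate=true stops after the header checks
err = parser.Parse(file, "field_double", false)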

View File

@ -0,0 +1,509 @@
package importutil
import (
"context"
"os"
"testing"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
)
func Test_NewNumpyParser(t *testing.T) {
ctx := context.Background()
parser := NewNumpyParser(ctx, nil, nil)
assert.Nil(t, parser)
}
func Test_ConvertNumpyType(t *testing.T) {
checkFunc := func(inputs []string, output schemapb.DataType) {
for i := 0; i < len(inputs); i++ {
dt, err := convertNumpyType(inputs[i])
assert.Nil(t, err)
assert.Equal(t, output, dt)
}
}
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
dt, err := convertNumpyType("dummy")
assert.NotNil(t, err)
assert.Equal(t, schemapb.DataType_None, dt)
}
func Test_Validate(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
schema := sampleSchema()
flushFunc := func(field storage.FieldData) error {
return nil
}
adapter := &NumpyAdapter{npyReader: &npy.Reader{}}
{
// string type is not supported
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_string",
IsPrimaryKey: false,
Description: "string",
DataType: schemapb.DataType_String,
},
},
}, flushFunc)
err = p.validate(adapter, "dummy")
assert.NotNil(t, err)
err = p.validate(adapter, "field_string")
assert.NotNil(t, err)
}
// the adapter is nil
parser := NewNumpyParser(ctx, schema, flushFunc)
err = parser.validate(nil, "")
assert.NotNil(t, err)
// validate scalar data
func() {
filePath := TempFilesPath + "scalar_1.npy"
data1 := []float64{0, 1, 2, 3, 4, 5}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.Nil(t, err)
assert.Equal(t, len(data1), parser.columnDesc.elementCount)
err = parser.validate(adapter, "")
assert.NotNil(t, err)
// data type mismatch
filePath = TempFilesPath + "scalar_2.npy"
data2 := []int64{0, 1, 2, 3, 4, 5}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "scalar_3.npy"
data3 := [][2]float64{{1, 1}}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file3.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
}()
// validate binary vector data
func() {
filePath := TempFilesPath + "binary_vector_1.npy"
data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.Nil(t, err)
assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)
// data type mismatch
filePath = TempFilesPath + "binary_vector_2.npy"
data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "binary_vector_3.npy"
data3 := []uint8{1, 2, 3}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file3.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// shape[1] mismatch
filePath = TempFilesPath + "binary_vector_4.npy"
data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
CreateNumpyFile(filePath, data4)
file4, err := os.Open(filePath)
assert.Nil(t, err)
defer file4.Close()
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// dimension mismatch
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_binary_vector",
DataType: schemapb.DataType_BinaryVector,
},
},
}, flushFunc)
err = p.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
}()
// validate float vector data
func() {
filePath := TempFilesPath + "float_vector.npy"
data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.Nil(t, err)
assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)
// data type mismatch
filePath = TempFilesPath + "float_vector_2.npy"
data2 := [][4]int32{{0, 1, 2, 3}}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "float_vector_3.npy"
data3 := []float32{1, 2, 3}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file3.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// shape[1] mismatch
filePath = TempFilesPath + "float_vector_4.npy"
data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}}
CreateNumpyFile(filePath, data4)
file4, err := os.Open(filePath)
assert.Nil(t, err)
defer file4.Close()
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// dimension mismatch
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_float_vector",
DataType: schemapb.DataType_FloatVector,
},
},
}, flushFunc)
err = p.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
}()
}
func Test_Parse(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
schema := sampleSchema()
checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) {
filePath := TempFilesPath + fieldName + ".npy"
CreateNumpyFile(filePath, data)
func() {
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(file, fieldName, false)
assert.Nil(t, err)
}()
// validation failed
func() {
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(file, "dummy", false)
assert.NotNil(t, err)
}()
// read data error
func() {
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(&MockReader{}, fieldName, false)
assert.NotNil(t, err)
}()
}
// scalar bool
data1 := []bool{true, false, true, false, true}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data1), field.RowNum())
for i := 0; i < len(data1); i++ {
assert.Equal(t, data1[i], field.GetRow(i))
}
return nil
}
checkFunc(data1, "field_bool", flushFunc)
// scalar int8
data2 := []int8{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data2), field.RowNum())
for i := 0; i < len(data2); i++ {
assert.Equal(t, data2[i], field.GetRow(i))
}
return nil
}
checkFunc(data2, "field_int8", flushFunc)
// scalar int16
data3 := []int16{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data3), field.RowNum())
for i := 0; i < len(data3); i++ {
assert.Equal(t, data3[i], field.GetRow(i))
}
return nil
}
checkFunc(data3, "field_int16", flushFunc)
// scalar int32
data4 := []int32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data4), field.RowNum())
for i := 0; i < len(data4); i++ {
assert.Equal(t, data4[i], field.GetRow(i))
}
return nil
}
checkFunc(data4, "field_int32", flushFunc)
// scalar int64
data5 := []int64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data5), field.RowNum())
for i := 0; i < len(data5); i++ {
assert.Equal(t, data5[i], field.GetRow(i))
}
return nil
}
checkFunc(data5, "field_int64", flushFunc)
// scalar float
data6 := []float32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data6), field.RowNum())
for i := 0; i < len(data6); i++ {
assert.Equal(t, data6[i], field.GetRow(i))
}
return nil
}
checkFunc(data6, "field_float", flushFunc)
// scalar double
data7 := []float64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data7), field.RowNum())
for i := 0; i < len(data7); i++ {
assert.Equal(t, data7[i], field.GetRow(i))
}
return nil
}
checkFunc(data7, "field_double", flushFunc)
// binary vector
data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data8), field.RowNum())
for i := 0; i < len(data8); i++ {
row := field.GetRow(i).([]uint8)
for k := 0; k < len(row); k++ {
assert.Equal(t, data8[i][k], row[k])
}
}
return nil
}
checkFunc(data8, "field_binary_vector", flushFunc)
// float vector
data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data9), field.RowNum())
for i := 0; i < len(data9); i++ {
row := field.GetRow(i).([]float32)
for k := 0; k < len(row); k++ {
assert.Equal(t, data9[i][k], row[k])
}
}
return nil
}
checkFunc(data9, "field_float_vector", flushFunc)
}
func Test_Parse_perf(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
tr := timerecord.NewTimeRecorder("numpy parse performance")
// change the parameter to test performance
rowCount := 10000
dotValue := float32(3.1415926)
const (
dim = 128
)
schema := perfSchema(dim)
data := make([][dim]float32, 0)
for i := 0; i < rowCount; i++ {
var row [dim]float32
for k := 0; k < dim; k++ {
row[k] = float32(i) + dotValue
}
data = append(data, row)
}
tr.Record("generate large data")
flushFunc := func(field storage.FieldData) error {
assert.Equal(t, len(data), field.RowNum())
return nil
}
filePath := TempFilesPath + "perf.npy"
CreateNumpyFile(filePath, data)
tr.Record("generate large numpy file " + filePath)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, flushFunc)
err = parser.Parse(file, "Vector", false)
assert.Nil(t, err)
tr.Record("parse large numpy files: " + filePath)
}

View File

@ -47,6 +47,7 @@ go test -race -cover "${MILVUS_DIR}/util/retry/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/sessionutil/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/trace/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/typeutil/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/importutil/..." -failfast
# TODO: remove to distributed
#go test -race -cover "${MILVUS_DIR}/proxy/..." -failfast