diff --git a/go.mod b/go.mod index 5cf23a9730..776ae5ddc1 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,8 @@ require ( github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.11.0 + github.com/quasilyte/go-ruleguard v0.2.1 // indirect + github.com/sbinet/npyio v0.6.0 github.com/shirou/gopsutil v3.21.8+incompatible github.com/spaolacci/murmur3 v1.1.0 github.com/spf13/cast v1.3.1 @@ -59,4 +61,5 @@ replace ( github.com/dgrijalva/jwt-go => github.com/golang-jwt/jwt v3.2.2+incompatible // Fix security alert for jwt-go 3.2.0 github.com/keybase/go-keychain => github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 google.golang.org/grpc => google.golang.org/grpc v1.38.0 -) \ No newline at end of file +) + diff --git a/go.sum b/go.sum index 1f4b48ff9e..f1a1fbc805 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,7 @@ github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJm github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM= github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -584,6 +585,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo= +github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -598,6 +601,8 @@ github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfF github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4= +github.com/sbinet/npyio v0.6.0 h1:IyqqQIzRjDym9xnIXsToCKei/qCzxDP+Y74KoMlMgXo= +github.com/sbinet/npyio v0.6.0/go.mod h1:/q3BNr6dJOy+t6h7RZchTJ0nwRJO52mivaem29WE1j8= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shirou/gopsutil v3.21.8+incompatible h1:sh0foI8tMRlCidUJR+KzqWYWxrkuuPIGiO6Vp+KXdCU= github.com/shirou/gopsutil v3.21.8+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= @@ -1038,6 +1043,7 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY 
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
 golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
+golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
 golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
 golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
 golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go
index 73406ab1c7..3e15824c59 100644
--- a/internal/util/importutil/import_wrapper.go
+++ b/internal/util/importutil/import_wrapper.go
@@ -7,6 +7,7 @@ import (
 	"os"
 	"path"
 	"strconv"
+	"strings"
 
 	"github.com/milvus-io/milvus/internal/allocator"
 	"github.com/milvus-io/milvus/internal/common"
@@ -89,6 +90,13 @@ func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldD
 	log.Debug(msg, stats...)
 }
 
+func getFileNameAndExt(filePath string) (string, string) {
+	fileName := path.Base(filePath)
+	fileType := path.Ext(fileName)
+	fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
+	return fileNameWithoutExt, fileType
+}
+
 // import process entry
 // filePath and rowBased are from ImportTask
 // if onlyValidate is true, this process only do validation, no data generated, callFlushFunc will not be called
@@ -99,8 +107,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
 	// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
 	for i := 0; i < len(filePaths); i++ {
 		filePath := filePaths[i]
-		fileName := path.Base(filePath)
-		fileType := path.Ext(fileName)
+		_, fileType := getFileNameAndExt(filePath)
 		log.Debug("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
 
 		if fileType == JSONFileExt {
@@ -183,8 +190,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
 	// parse/validate/consume data
 	for i := 0; i < len(filePaths); i++ {
 		filePath := filePaths[i]
-		fileName := path.Base(filePath)
-		fileType := path.Ext(fileName)
+		fileName, fileType := getFileNameAndExt(filePath)
 		log.Debug("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))
 
 		if fileType == JSONFileExt {
@@ -218,7 +224,28 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
 				return err
 			}
 		} else if fileType == NumpyFileExt {
+			file, err := os.Open(filePath)
+			if err != nil {
+				log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
+				return err
+			}
+			defer file.Close()
+
+			// the numpy parser returns a storage.FieldData; wrap it into a
+			// map[string]storage.FieldData so it can be combined with the other columns
+			flushFunc := func(field storage.FieldData) error {
+				fields := make(map[string]storage.FieldData)
+				fields[fileName] = field
+				return combineFunc(fields)
+			}
+
+			// for numpy files, the file name (without extension) is treated as the field name
+			parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
+			err = parser.Parse(file, fileName, onlyValidate)
+			if err != nil {
+				log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
+				return err
+			}
 		}
 	}
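A small self-contained sketch may help clarify the naming convention used above: in column-based import, each numpy file feeds exactly one field, selected by the file's base name. The paths below are made up for illustration, and getFileNameAndExt mirrors the helper added in import_wrapper.go.

package main

import (
	"fmt"
	"path"
	"strings"
)

// mirrors the getFileNameAndExt helper added in import_wrapper.go
func getFileNameAndExt(filePath string) (string, string) {
	fileName := path.Base(filePath)
	fileType := path.Ext(fileName)
	return strings.TrimSuffix(fileName, fileType), fileType
}

func main() {
	// in column-based import, "field_float_vector.npy" feeds the field "field_float_vector"
	for _, f := range []string{
		"/data/field_float_vector.npy",
		"/data/field_bool.npy",
		"/data/scalar_fields.json",
	} {
		name, ext := getFileNameAndExt(f)
		fmt.Printf("%-35s -> field=%q ext=%q\n", f, name, ext)
	}
}
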
diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 33e9d5df96..b9d4d40359 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -126,7 +126,7 @@ func Test_ImportRowBased(t *testing.T) { } -func Test_ImportColumnBased(t *testing.T) { +func Test_ImportColumnBased_json(t *testing.T) { ctx := context.Background() err := os.MkdirAll(TempFilesPath, os.ModePerm) assert.Nil(t, err) @@ -208,6 +208,88 @@ func Test_ImportColumnBased(t *testing.T) { assert.NotNil(t, err) } +func Test_ImportColumnBased_numpy(t *testing.T) { + ctx := context.Background() + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "field_bool": [true, false, true, true, true], + "field_int8": [10, 11, 12, 13, 14], + "field_int16": [100, 101, 102, 103, 104], + "field_int32": [1000, 1001, 1002, 1003, 1004], + "field_int64": [10000, 10001, 10002, 10003, 10004], + "field_float": [3.14, 3.15, 3.16, 3.17, 3.18], + "field_double": [5.1, 5.2, 5.3, 5.4, 5.5], + "field_string": ["a", "b", "c", "d", "e"] + }`) + + files := make([]string, 0) + + filePath := TempFilesPath + "scalar_fields.json" + fp1 := saveFile(t, filePath, content) + fp1.Close() + files = append(files, filePath) + + filePath = TempFilesPath + "field_binary_vector.npy" + bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}} + err = CreateNumpyFile(filePath, bin) + assert.Nil(t, err) + files = append(files, filePath) + + filePath = TempFilesPath + "field_float_vector.npy" + flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}} + err = CreateNumpyFile(filePath, flo) + assert.Nil(t, err) + files = append(files, filePath) + + rowCount := 0 + flushFunc := func(fields map[string]storage.FieldData) error { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCount += count + return nil + } + + // success case + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + + err = wrapper.Import(files, false, false) + assert.Nil(t, err) + assert.Equal(t, 5, rowCount) + + // parse error + content = []byte(`{ + "field_bool": [true, false, true, true, true] + }`) + + filePath = TempFilesPath + "rows_2.json" + fp2 := saveFile(t, filePath, content) + defer fp2.Close() + + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + files = make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, false, false) + assert.NotNil(t, err) + + // file doesn't exist + files = make([]string, 0) + files = append(files, "/dummy/dummy.json") + err = wrapper.Import(files, false, false) + assert.NotNil(t, err) +} + func perfSchema(dim int) *schemapb.CollectionSchema { schema := &schemapb.CollectionSchema{ Name: "schema", diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index 9836027bbe..cf0d67738f 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -24,7 +24,7 @@ type JSONParser struct { fields map[string]int64 // fields need to be parsed } -// newImportManager helper function to create a importManager +// NewJSONParser helper function to create a JSONParser func NewJSONParser(ctx context.Context, collectionSchema 
*schemapb.CollectionSchema) *JSONParser {
 	fields := make(map[string]int64)
 	for i := 0; i < len(collectionSchema.Fields); i++ {
diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go
new file mode 100644
index 0000000000..7e03f58275
--- /dev/null
+++ b/internal/util/importutil/numpy_adapter.go
@@ -0,0 +1,356 @@
+package importutil
+
+import (
+	"encoding/binary"
+	"errors"
+	"io"
+	"os"
+
+	"github.com/sbinet/npyio"
+	"github.com/sbinet/npyio/npy"
+)
+
+func CreateNumpyFile(path string, data interface{}) error {
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	err = npyio.Write(f, data)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// NumpyAdapter extends the abilities of a third-party numpy library.
+// We evaluated two Go numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio.
+// The npyio lib reads data element by element, which performs poorly, so we extend its
+// read methods to read data in one batch; this is about 100X faster.
+// The gonpy lib also reads data in one batch, but it has no method to read bool data, and
+// its ability to handle different data types is not as strong as npyio's, so we chose
+// the npyio lib to extend.
+type NumpyAdapter struct {
+	reader       io.Reader        // data source, typically is os.File
+	npyReader    *npy.Reader      // reader of npyio lib
+	order        binary.ByteOrder // LittleEndian or BigEndian
+	readPosition int              // how many elements have been read
+}
+
+func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
+	r, err := npyio.NewReader(reader)
+	if err != nil {
+		return nil, err
+	}
+	adapter := &NumpyAdapter{
+		reader:       reader,
+		npyReader:    r,
+		readPosition: 0,
+	}
+	adapter.setByteOrder()
+
+	return adapter, nil
+}
+
+// the logic of this method is copied from the npyio lib
+func (n *NumpyAdapter) setByteOrder() {
+	var nativeEndian binary.ByteOrder
+	v := uint16(1)
+	switch byte(v >> 8) {
+	case 0:
+		nativeEndian = binary.LittleEndian
+	case 1:
+		nativeEndian = binary.BigEndian
+	}
+
+	switch n.npyReader.Header.Descr.Type[0] {
+	case '<':
+		n.order = binary.LittleEndian
+	case '>':
+		n.order = binary.BigEndian
+	default:
+		n.order = nativeEndian
+	}
+}
+
+func (n *NumpyAdapter) Reader() io.Reader {
+	return n.reader
+}
+
+func (n *NumpyAdapter) NpyReader() *npy.Reader {
+	return n.npyReader
+}
+
+func (n *NumpyAdapter) GetType() string {
+	return n.npyReader.Header.Descr.Type
+}
+
+func (n *NumpyAdapter) GetShape() []int {
+	return n.npyReader.Header.Descr.Shape
+}
+
+func (n *NumpyAdapter) checkSize(size int) int {
+	shape := n.GetShape()
+
+	// empty file?
+	if len(shape) == 0 {
+		return 0
+	}
+
+	total := 1
+	for i := 0; i < len(shape); i++ {
+		total *= shape[i]
+	}
+
+	if total == 0 {
+		return 0
+	}
+
+	// overflow?
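+	// note: the requested element count is clamped to what remains unread in the file,
+	// so a caller can simply pass the total element count reported by the header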
+	if size > (total - n.readPosition) {
+		return total - n.readPosition
+	}
+
+	return size
+}
+
+func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "b1", "<b1", "|b1", "bool":
+	default:
+		return nil, errors.New("numpy data is not bool type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]bool, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "u1", "<u1", "|u1", "uint8":
+	default:
+		return nil, errors.New("numpy data is not uint8 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]uint8, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "i1", "<i1", "|i1", ">i1", "int8":
+	default:
+		return nil, errors.New("numpy data is not int8 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]int8, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "i2", "<i2", "|i2", ">i2", "int16":
+	default:
+		return nil, errors.New("numpy data is not int16 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]int16, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "i4", "<i4", "|i4", ">i4", "int32":
+	default:
+		return nil, errors.New("numpy data is not int32 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]int32, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "i8", "<i8", "|i8", ">i8", "int64":
+	default:
+		return nil, errors.New("numpy data is not int64 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]int64, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "f4", "<f4", "|f4", ">f4", "float32":
+	default:
+		return nil, errors.New("numpy data is not float32 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]float32, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
+
+func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
+	if n.npyReader == nil {
+		return nil, errors.New("reader is not initialized")
+	}
+
+	// incorrect type
+	switch n.npyReader.Header.Descr.Type {
+	case "f8", "<f8", "|f8", ">f8", "float64":
+	default:
+		return nil, errors.New("numpy data is not float64 type")
+	}
+
+	// avoid read overflow
+	readSize := n.checkSize(size)
+	if readSize <= 0 {
+		return nil, errors.New("nothing to read")
+	}
+
+	data := make([]float64, readSize)
+	err := binary.Read(n.reader, n.order, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	// update the read position after a successful read
+	n.readPosition += readSize
+
+	return data, nil
+}
diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go
new file mode 100644
index 0000000000..f91eb04385
--- /dev/null
+++ b/internal/util/importutil/numpy_adapter_test.go
@@ -0,0 +1,455 @@
+package importutil
+
+import (
+	"encoding/binary"
+	"io"
+	"os"
+	"testing"
+
+	"github.com/sbinet/npyio/npy"
+	"github.com/stretchr/testify/assert"
+)
+
+type MockReader struct {
+}
+
+func (r *MockReader) Read(p []byte) (n int, err error) {
+	return 0, io.EOF
+}
+
+func Test_CreateNumpyFile(t *testing.T) {
+	// directory doesn't exist
+	data1 := []float32{1, 2, 3, 4, 5}
+	err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
+	assert.NotNil(t, err)
+
+	// invalid data type
+	data2 := make(map[string]int)
+	err = CreateNumpyFile("/tmp/dummy.npy", data2)
+	assert.NotNil(t, err)
+}
+
+func Test_SetByteOrder(t *testing.T) {
+	adapter := &NumpyAdapter{
+		reader:    nil,
+		npyReader: &npy.Reader{},
+	}
+	assert.Nil(t, adapter.Reader())
+	assert.NotNil(t, adapter.NpyReader())
+
+	adapter.npyReader.Header.Descr.Type = "<i2"
+	adapter.setByteOrder()
+	assert.Equal(t, binary.LittleEndian, adapter.order)
+
+	adapter.npyReader.Header.Descr.Type = ">i2"
+	adapter.setByteOrder()
+	assert.Equal(t, binary.BigEndian, adapter.order)
+}
diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go
new file mode 100644
--- /dev/null
+++ b/internal/util/importutil/numpy_parser.go
+package importutil
+
+import (
+	"context"
+	"errors"
+	"io"
+	"strconv"
+
+	"github.com/milvus-io/milvus/internal/log"
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+	"github.com/milvus-io/milvus/internal/storage"
+)
+
+type ColumnDesc struct {
+	name         string            // name of the target column
+	dt           schemapb.DataType // data type of the target column
+	elementCount int               // how many elements need to be read
+	dimension    int               // only for vector fields
+}
+
+type NumpyParser struct {
+	ctx              context.Context                     // for canceling the parse process
+	collectionSchema *schemapb.CollectionSchema          // collection schema
+	columnDesc       *ColumnDesc                         // description of the target column
+	columnData       storage.FieldData                   // in-memory column data
+	callFlushFunc    func(field storage.FieldData) error // callback to output the column data
+}
+
+// NewNumpyParser helper function to create a NumpyParser
+func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
+	flushFunc func(field storage.FieldData) error) *NumpyParser {
+	if collectionSchema == nil || flushFunc == nil {
+		return nil
+	}
+
+	return &NumpyParser{
+		ctx:              ctx,
+		collectionSchema: collectionSchema,
+		columnDesc:       &ColumnDesc{},
+		callFlushFunc:    flushFunc,
+	}
+}
+
+func (p *NumpyParser) logError(msg string) error {
+	log.Error(msg)
+	return errors.New(msg)
+}
+
+func convertNumpyType(str string) (schemapb.DataType, error) {
+	switch str {
+	case "b1", "<b1", "|b1", "bool":
+		return schemapb.DataType_Bool, nil
+	case "u1", "<u1", "|u1", "uint8":
+		return schemapb.DataType_BinaryVector, nil
+	case "i1", "<i1", "|i1", ">i1", "int8":
+		return schemapb.DataType_Int8, nil
+	case "i2", "<i2", "|i2", ">i2", "int16":
+		return schemapb.DataType_Int16, nil
+	case "i4", "<i4", "|i4", ">i4", "int32":
+		return schemapb.DataType_Int32, nil
+	case "i8", "<i8", "|i8", ">i8", "int64":
+		return schemapb.DataType_Int64, nil
+	case "f4", "<f4", "|f4", ">f4", "float32":
+		return schemapb.DataType_Float, nil
+	case "f8", "<f8", "|f8", ">f8", "float64":
+		return schemapb.DataType_Double, nil
+	default:
+		return schemapb.DataType_None, errors.New("unsupported data type " + str)
+	}
+}
+
+func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
+	if adapter == nil {
+		return errors.New("numpy adapter is nil")
+	}
+
+	// check existence of the target field
+	var schema *schemapb.FieldSchema
+	for i := 0; i < len(p.collectionSchema.Fields); i++ {
+		schema = p.collectionSchema.Fields[i]
+		if schema.GetName() == fieldName {
+			p.columnDesc.name = fieldName
+			break
+		}
+	}
+
+	if p.columnDesc.name == "" {
+		return errors.New("the field " + fieldName + " doesn't exist")
+	}
+
+	p.columnDesc.dt = schema.DataType
+	elementType, err := convertNumpyType(adapter.GetType())
+	if err != nil {
+		return err
+	}
+
+	shape := adapter.GetShape()
+
+	// 1. the field data type should be consistent with the numpy data type
+	// 2. a vector field dimension should be consistent with the numpy shape
+	if schemapb.DataType_FloatVector == schema.DataType {
+		if elementType != schemapb.DataType_Float {
+			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
+		}
+
+		// vector field, the shape length should be 2
+		if len(shape) != 2 {
+			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
+		}
+
+		// shape[0] is row count, shape[1] is element count per row
+		p.columnDesc.elementCount = shape[0] * shape[1]
+
+		p.columnDesc.dimension, err = getFieldDimension(schema)
+		if err != nil {
+			return err
+		}
+
+		if shape[1] != p.columnDesc.dimension {
+			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
+		}
+	} else if schemapb.DataType_BinaryVector == schema.DataType {
+		if elementType != schemapb.DataType_BinaryVector {
+			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
+		}
+
+		// vector field, the shape length should be 2
+		if len(shape) != 2 {
+			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
+		}
+
+		// shape[0] is row count, shape[1] is element count per row
+		p.columnDesc.elementCount = shape[0] * shape[1]
+
+		p.columnDesc.dimension, err = getFieldDimension(schema)
+		if err != nil {
+			return err
+		}
+
+		if shape[1] != p.columnDesc.dimension/8 {
+			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
+		}
+	} else {
+		if elementType != schema.DataType {
+			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
+		}
+
+		// scalar field, the shape length should be 1
+		if len(shape) != 1 {
+			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
+		}
+
+		p.columnDesc.elementCount = shape[0]
+	}
+
+	return nil
+}
+
+// this method reads the numpy data section into a storage.FieldData
+// please note it requires a large memory block (almost equal to the numpy file size)
+func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
+	switch p.columnDesc.dt {
+	case schemapb.DataType_Bool:
+		data, err := adapter.ReadBool(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.BoolFieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+
+	case schemapb.DataType_Int8:
+		data, err := adapter.ReadInt8(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.Int8FieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_Int16:
+		data, err := adapter.ReadInt16(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.Int16FieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_Int32:
+		data, err := adapter.ReadInt32(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.Int32FieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_Int64:
+		data, err := adapter.ReadInt64(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.Int64FieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_Float:
+		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.FloatFieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_Double:
+		data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.DoubleFieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+		}
+	case schemapb.DataType_BinaryVector:
+		data, err := adapter.ReadUint8(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.BinaryVectorFieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+			Dim:     p.columnDesc.dimension,
+		}
+	case schemapb.DataType_FloatVector:
+		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
+		if err != nil {
+			return err
+		}
+
+		p.columnData = &storage.FloatVectorFieldData{
+			NumRows: []int64{int64(p.columnDesc.elementCount)},
+			Data:    data,
+			Dim:     p.columnDesc.dimension,
+		}
+	default:
+		return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
+	}
+
+	return nil
+}
+
+func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
+	adapter, err := NewNumpyAdapter(reader)
+	if err != nil {
+		return p.logError("Numpy parse: " + err.Error())
+	}
+
+	// the validation method only checks the file header information
+	err = p.validate(adapter, fieldName)
+	if err != nil {
+		return p.logError("Numpy parse: " + err.Error())
+	}
+
+	if onlyValidate {
+		return nil
+	}
+
+	// read all data from the numpy file
+	err = p.consume(adapter)
+	if err != nil {
+		return p.logError("Numpy parse: " + err.Error())
+	}
+
+	return p.callFlushFunc(p.columnData)
+}
diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go
new file mode 100644
index 0000000000..4bb1207486
--- /dev/null
+++ b/internal/util/importutil/numpy_parser_test.go
@@ -0,0 +1,509 @@
+package importutil
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/timerecord"
+	"github.com/sbinet/npyio/npy"
+	"github.com/stretchr/testify/assert"
+)
+
+func Test_NewNumpyParser(t *testing.T) {
+	ctx := context.Background()
+
+	parser := NewNumpyParser(ctx, nil, nil)
+	assert.Nil(t, parser)
+}
+
+func Test_ConvertNumpyType(t *testing.T) {
+	checkFunc := func(inputs []string, output schemapb.DataType) {
+		for i := 0; i < len(inputs); i++ {
+			dt, err := convertNumpyType(inputs[i])
+			assert.Nil(t, err)
+			assert.Equal(t, output, dt)
+		}
+	}
+
+	checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
+	checkFunc([]string{"u1", "<u1", "|u1", "uint8"}, schemapb.DataType_BinaryVector)
+	checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
+	checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
+	checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
+	checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
+	checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
+	checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
+
+	dt, err := convertNumpyType("dummy")
+	assert.NotNil(t, err)
+	assert.Equal(t, schemapb.DataType_None, dt)
+}
+
+func Test_Validate(t *testing.T) {
+	ctx := context.Background()
+	err := os.MkdirAll(TempFilesPath, os.ModePerm)
+	assert.Nil(t, err)
+	defer os.RemoveAll(TempFilesPath)
+
+	schema := sampleSchema()
+	flushFunc := func(field storage.FieldData) error {
+		return nil
+	}
+
+	adapter := &NumpyAdapter{npyReader: &npy.Reader{}}
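+	// an adapter with an empty npy header: GetType() returns "", which convertNumpyType rejects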
+
+	{
+		// string type is not supported
+		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					FieldID:      109,
+					Name:         "field_string",
+					IsPrimaryKey: false,
+					Description:  "string",
+					DataType:     schemapb.DataType_String,
+				},
+			},
+		}, flushFunc)
+		err = p.validate(adapter, "dummy")
+		assert.NotNil(t, err)
+		err = p.validate(adapter, "field_string")
+		assert.NotNil(t, err)
+	}
+
+	// adapter is nil
+	parser := NewNumpyParser(ctx, schema, flushFunc)
+	err = parser.validate(nil, "")
+	assert.NotNil(t, err)
+
+	// validate scalar data
+	func() {
+		filePath := TempFilesPath + "scalar_1.npy"
+		data1 := []float64{0, 1, 2, 3, 4, 5}
+		CreateNumpyFile(filePath, data1)
+
+		file1, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file1.Close()
+
+		adapter, err := NewNumpyAdapter(file1)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_double")
+		assert.Nil(t, err)
+		assert.Equal(t, len(data1), parser.columnDesc.elementCount)
+
+		err = parser.validate(adapter, "")
+		assert.NotNil(t, err)
+
+		// data type mismatch
+		filePath = TempFilesPath + "scalar_2.npy"
+		data2 := []int64{0, 1, 2, 3, 4, 5}
+		CreateNumpyFile(filePath, data2)
+
+		file2, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file2.Close()
+
+		adapter, err = NewNumpyAdapter(file2)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_double")
+		assert.NotNil(t, err)
+
+		// shape mismatch
+		filePath = TempFilesPath + "scalar_3.npy"
+		data3 := [][2]float64{{1, 1}}
+		CreateNumpyFile(filePath, data3)
+
+		file3, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file3.Close()
+
+		adapter, err = NewNumpyAdapter(file3)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_double")
+		assert.NotNil(t, err)
+	}()
+
+	// validate binary vector data
+	func() {
+		filePath := TempFilesPath + "binary_vector_1.npy"
+		data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
+		CreateNumpyFile(filePath, data1)
+
+		file1, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file1.Close()
+
+		adapter, err := NewNumpyAdapter(file1)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_binary_vector")
+		assert.Nil(t, err)
+		assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)
+
+		// data type mismatch
+		filePath = TempFilesPath + "binary_vector_2.npy"
+		data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
+		CreateNumpyFile(filePath, data2)
+
+		file2, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file2.Close()
+
+		adapter, err = NewNumpyAdapter(file2)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_binary_vector")
+		assert.NotNil(t, err)
+
+		// shape mismatch
+		filePath = TempFilesPath + "binary_vector_3.npy"
+		data3 := []uint8{1, 2, 3}
+		CreateNumpyFile(filePath, data3)
+
+		file3, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file3.Close()
+
+		adapter, err = NewNumpyAdapter(file3)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_binary_vector")
+		assert.NotNil(t, err)
+
+		// shape[1] mismatch
+		filePath = TempFilesPath + "binary_vector_4.npy"
+		data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
+		CreateNumpyFile(filePath, data4)
+
+		file4, err := os.Open(filePath)
+		assert.Nil(t, err)
+		defer file4.Close()
+
+		adapter, err = NewNumpyAdapter(file4)
+		assert.Nil(t, err)
+
+		err = parser.validate(adapter, "field_binary_vector")
+		assert.NotNil(t, err)
+
+		// dimension mismatch
+		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					FieldID:  109,
+					Name:     "field_binary_vector",
+					DataType:
schemapb.DataType_BinaryVector, + }, + }, + }, flushFunc) + + err = p.validate(adapter, "field_binary_vector") + assert.NotNil(t, err) + }() + + // validate float vector data + func() { + filePath := TempFilesPath + "float_vector.npy" + data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}} + CreateNumpyFile(filePath, data1) + + file1, err := os.Open(filePath) + assert.Nil(t, err) + defer file1.Close() + + adapter, err := NewNumpyAdapter(file1) + assert.Nil(t, err) + + err = parser.validate(adapter, "field_float_vector") + assert.Nil(t, err) + assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount) + + // data type mismatch + filePath = TempFilesPath + "float_vector_2.npy" + data2 := [][4]int32{{0, 1, 2, 3}} + CreateNumpyFile(filePath, data2) + + file2, err := os.Open(filePath) + assert.Nil(t, err) + defer file2.Close() + + adapter, err = NewNumpyAdapter(file2) + assert.Nil(t, err) + + err = parser.validate(adapter, "field_float_vector") + assert.NotNil(t, err) + + // shape mismatch + filePath = TempFilesPath + "float_vector_3.npy" + data3 := []float32{1, 2, 3} + CreateNumpyFile(filePath, data3) + + file3, err := os.Open(filePath) + assert.Nil(t, err) + defer file3.Close() + + adapter, err = NewNumpyAdapter(file3) + assert.Nil(t, err) + + err = parser.validate(adapter, "field_float_vector") + assert.NotNil(t, err) + + // shape[1] mismatch + filePath = TempFilesPath + "float_vector_4.npy" + data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}} + CreateNumpyFile(filePath, data4) + + file4, err := os.Open(filePath) + assert.Nil(t, err) + defer file4.Close() + + adapter, err = NewNumpyAdapter(file4) + assert.Nil(t, err) + + err = parser.validate(adapter, "field_float_vector") + assert.NotNil(t, err) + + // dimension mismatch + p := NewNumpyParser(ctx, &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 109, + Name: "field_float_vector", + DataType: schemapb.DataType_FloatVector, + }, + }, + }, flushFunc) + + err = p.validate(adapter, "field_float_vector") + assert.NotNil(t, err) + }() +} + +func Test_Parse(t *testing.T) { + ctx := context.Background() + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.Nil(t, err) + defer os.RemoveAll(TempFilesPath) + + schema := sampleSchema() + + checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) { + + filePath := TempFilesPath + fieldName + ".npy" + CreateNumpyFile(filePath, data) + + func() { + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() + + parser := NewNumpyParser(ctx, schema, callback) + err = parser.Parse(file, fieldName, false) + assert.Nil(t, err) + }() + + // validation failed + func() { + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() + + parser := NewNumpyParser(ctx, schema, callback) + err = parser.Parse(file, "dummy", false) + assert.NotNil(t, err) + }() + + // read data error + func() { + parser := NewNumpyParser(ctx, schema, callback) + err = parser.Parse(&MockReader{}, fieldName, false) + assert.NotNil(t, err) + }() + } + + // scalar bool + data1 := []bool{true, false, true, false, true} + flushFunc := func(field storage.FieldData) error { + assert.NotNil(t, field) + assert.Equal(t, len(data1), field.RowNum()) + + for i := 0; i < len(data1); i++ { + assert.Equal(t, data1[i], field.GetRow(i)) + } + + return nil + } + checkFunc(data1, "field_bool", flushFunc) + + // scalar int8 + data2 := []int8{1, 2, 3, 4, 5} + flushFunc = func(field storage.FieldData) error { + assert.NotNil(t, 
field)
+		assert.Equal(t, len(data2), field.RowNum())
+
+		for i := 0; i < len(data2); i++ {
+			assert.Equal(t, data2[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data2, "field_int8", flushFunc)
+
+	// scalar int16
+	data3 := []int16{1, 2, 3, 4, 5}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data3), field.RowNum())
+
+		for i := 0; i < len(data3); i++ {
+			assert.Equal(t, data3[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data3, "field_int16", flushFunc)
+
+	// scalar int32
+	data4 := []int32{1, 2, 3, 4, 5}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data4), field.RowNum())
+
+		for i := 0; i < len(data4); i++ {
+			assert.Equal(t, data4[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data4, "field_int32", flushFunc)
+
+	// scalar int64
+	data5 := []int64{1, 2, 3, 4, 5}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data5), field.RowNum())
+
+		for i := 0; i < len(data5); i++ {
+			assert.Equal(t, data5[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data5, "field_int64", flushFunc)
+
+	// scalar float
+	data6 := []float32{1, 2, 3, 4, 5}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data6), field.RowNum())
+
+		for i := 0; i < len(data6); i++ {
+			assert.Equal(t, data6[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data6, "field_float", flushFunc)
+
+	// scalar double
+	data7 := []float64{1, 2, 3, 4, 5}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data7), field.RowNum())
+
+		for i := 0; i < len(data7); i++ {
+			assert.Equal(t, data7[i], field.GetRow(i))
+		}
+
+		return nil
+	}
+	checkFunc(data7, "field_double", flushFunc)
+
+	// binary vector
+	data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data8), field.RowNum())
+
+		for i := 0; i < len(data8); i++ {
+			row := field.GetRow(i).([]uint8)
+			for k := 0; k < len(row); k++ {
+				assert.Equal(t, data8[i][k], row[k])
+			}
+		}
+
+		return nil
+	}
+	checkFunc(data8, "field_binary_vector", flushFunc)
+
+	// float vector
+	data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
+	flushFunc = func(field storage.FieldData) error {
+		assert.NotNil(t, field)
+		assert.Equal(t, len(data9), field.RowNum())
+
+		for i := 0; i < len(data9); i++ {
+			row := field.GetRow(i).([]float32)
+			for k := 0; k < len(row); k++ {
+				assert.Equal(t, data9[i][k], row[k])
+			}
+		}
+
+		return nil
+	}
+	checkFunc(data9, "field_float_vector", flushFunc)
+}
+
+func Test_Parse_perf(t *testing.T) {
+	ctx := context.Background()
+	err := os.MkdirAll(TempFilesPath, os.ModePerm)
+	assert.Nil(t, err)
+	defer os.RemoveAll(TempFilesPath)
+
+	tr := timerecord.NewTimeRecorder("numpy parse performance")
+
+	// change the parameter to test performance
+	rowCount := 10000
+	dotValue := float32(3.1415926)
+	const (
+		dim = 128
+	)
+
+	schema := perfSchema(dim)
+
+	data := make([][dim]float32, 0)
+	for i := 0; i < rowCount; i++ {
+		var row [dim]float32
+		for k := 0; k < dim; k++ {
+			row[k] = float32(i) + dotValue
+		}
+		data = append(data, row)
+	}
+
+	tr.Record("generate large data")
+
+	flushFunc := func(field storage.FieldData) error {
+		assert.Equal(t, len(data), field.RowNum())
+		return nil
+	}
+
+	filePath := TempFilesPath + "perf.npy"
+	CreateNumpyFile(filePath, data)
+
+
tr.Record("generate large numpy file " + filePath) + + file, err := os.Open(filePath) + assert.Nil(t, err) + defer file.Close() + + parser := NewNumpyParser(ctx, schema, flushFunc) + err = parser.Parse(file, "Vector", false) + assert.Nil(t, err) + + tr.Record("parse large numpy files: " + filePath) +} diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index d13d1073eb..7bf411daae 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -47,6 +47,7 @@ go test -race -cover "${MILVUS_DIR}/util/retry/..." -failfast go test -race -cover "${MILVUS_DIR}/util/sessionutil/..." -failfast go test -race -cover "${MILVUS_DIR}/util/trace/..." -failfast go test -race -cover "${MILVUS_DIR}/util/typeutil/..." -failfast +go test -race -cover "${MILVUS_DIR}/util/importutil/..." -failfast # TODO: remove to distributed #go test -race -cover "${MILVUS_DIR}/proxy/..." -failfast