milvus/tests/integration/sparse/sparse_test.go
Buqian Zheng f4a91e135b
enhance: Allow empty sparse row (#34700)
issue: #29419

* If a sparse vector with no non-zero values is inserted, no ANN search on
this sparse vector field will return it as a result. The user may still retrieve
this row via a scalar query or an ANN search on another vector field.
* If the user uses an empty sparse vector as the query vector for an ANN
search, no neighbors will be returned.

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
2024-08-16 14:14:54 +08:00

534 lines
16 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sparse_test
import (
"context"
"encoding/binary"
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus/tests/integration"
)
// SparseTestSuite groups the integration tests for sparse float vector
// fields. It embeds integration.MiniClusterSuite, whose Cluster field the
// tests in this file use to reach the proxy.
type SparseTestSuite struct {
integration.MiniClusterSuite
}
// createCollection creates a collection in dbName through the proxy of c and
// returns its name, which is "TestSparse" followed by a random suffix.
//
// The schema consists of an auto-ID Int64 primary key (field 100) and a sparse
// float vector field (field 101) declared without any type params. The method
// asserts that both CreateCollection and a follow-up ShowCollections call
// report a success status.
func (s *SparseTestSuite) createCollection(ctx context.Context, c *integration.MiniClusterV2, dbName string) string {
	collectionName := "TestSparse" + funcutil.GenRandomStr()

	pk := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         integration.Int64Field,
		IsPrimaryKey: true,
		DataType:     schemapb.DataType_Int64,
		AutoID:       true,
	}
	fVec := &schemapb.FieldSchema{
		FieldID:  101,
		Name:     integration.SparseFloatVecField,
		DataType: schemapb.DataType_SparseFloatVector,
	}
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{pk, fVec},
	}

	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)
	// testify expects (expected, actual); keeping that order makes failure
	// output label the two values correctly.
	s.Equal(commonpb.ErrorCode_Success, createCollectionStatus.GetErrorCode())
	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))

	showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, showCollectionsResp.GetStatus().GetErrorCode())
	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))

	return collectionName
}
// TestSparse_should_not_speficy_dim checks that CreateCollection reports a
// non-success status when the sparse float vector field carries a dim type
// param.
func (s *SparseTestSuite) TestSparse_should_not_speficy_dim() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c := s.Cluster

	const (
		dbName = ""
		// dim is the value given to the offending dim type param.
		dim = "10"
	)

	collectionName := "TestSparse" + funcutil.GenRandomStr()

	pk := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         integration.Int64Field,
		IsPrimaryKey: true,
		DataType:     schemapb.DataType_Int64,
		AutoID:       true,
	}
	fVec := &schemapb.FieldSchema{
		FieldID:  101,
		Name:     integration.SparseFloatVecField,
		DataType: schemapb.DataType_SparseFloatVector,
		// This type param is what the test expects the server to reject.
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   common.DimKey,
				Value: dim,
			},
		},
	}
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{pk, fVec},
	}

	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	// The RPC itself must succeed; the rejection is reported in the status.
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createCollectionStatus.GetErrorCode())
}
// TestSparse_invalid_insert inserts a generated sparse vector column several
// times, mutating its first row between attempts, and checks which variants
// the proxy accepts:
//   - the unmodified column is accepted;
//   - a first row with 4 extra trailing bytes is rejected;
//   - an empty first row is accepted;
//   - a first row whose column indices are not ascending is rejected.
func (s *SparseTestSuite) TestSparse_invalid_insert() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)

	// insert sends the current contents of fVecColumn and returns the status
	// code of the response. The column is shared by reference, so mutations
	// made through sparseVecs below are picked up by the next call.
	insert := func() commonpb.ErrorCode {
		insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
			DbName:         dbName,
			CollectionName: collectionName,
			FieldsData:     []*schemapb.FieldData{fVecColumn},
			HashKeys:       hashKeys,
			NumRows:        uint32(rowNum),
		})
		s.NoError(err)
		return insertResult.GetStatus().GetErrorCode()
	}

	// valid insert
	s.Equal(commonpb.ErrorCode_Success, insert())

	sparseVecs := fVecColumn.Field.(*schemapb.FieldData_Vectors).Vectors.GetSparseFloatVector()

	// of each row, length of indices and data must equal
	sparseVecs.Contents[0] = append(sparseVecs.Contents[0], make([]byte, 4)...)
	s.NotEqual(commonpb.ErrorCode_Success, insert())

	// empty row is allowed
	sparseVecs.Contents[0] = []byte{}
	s.Equal(commonpb.ErrorCode_Success, insert())

	// unsorted column index is not allowed: the row holds two entries whose
	// indices (20, then 10) are in descending order.
	sparseVecs.Contents[0] = make([]byte, 16)
	typeutil.SparseFloatRowSetAt(sparseVecs.Contents[0], 0, 20, 0.1)
	typeutil.SparseFloatRowSetAt(sparseVecs.Contents[0], 1, 10, 0.2)
	s.NotEqual(commonpb.ErrorCode_Success, insert())
}
// TestSparse_invalid_index_build inserts and flushes a valid sparse vector
// column, then checks that CreateIndex on the sparse vector field reports a
// non-success status for each invalid index configuration listed below.
func (s *SparseTestSuite) TestSparse_invalid_index_build() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	// valid insert
	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	// Confirm the collection is present in the response before using its entry.
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	s.Require().True(has)
	s.Require().NotEmpty(segmentIDs)
	ids := segmentIDs.GetData()
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	kv := func(key, value string) *commonpb.KeyValuePair {
		return &commonpb.KeyValuePair{Key: key, Value: value}
	}

	// Each case is meant to contain exactly one invalid setting so that a
	// rejection can be attributed to it. In particular the drop-ratio cases
	// use metric.IP: with metric.L2 (rejected on its own by the
	// "incorrect metric type" case) they would pass without the drop ratio
	// ever being validated.
	invalidCases := []struct {
		name   string
		params []*commonpb.KeyValuePair
	}{
		{
			name: "unsupported index type",
			params: []*commonpb.KeyValuePair{
				kv(common.IndexTypeKey, integration.IndexFaissIvfPQ),
				kv(common.MetricTypeKey, metric.IP),
			},
		},
		{
			name: "nonexist index",
			params: []*commonpb.KeyValuePair{
				kv(common.IndexTypeKey, "INDEX_WHAT"),
				kv(common.MetricTypeKey, metric.IP),
			},
		},
		{
			name: "incorrect metric type",
			params: []*commonpb.KeyValuePair{
				kv(common.IndexTypeKey, integration.IndexSparseInvertedIndex),
				kv(common.MetricTypeKey, metric.L2),
			},
		},
		{
			name: "incorrect drop ratio build: negative",
			params: []*commonpb.KeyValuePair{
				kv(common.IndexTypeKey, integration.IndexSparseInvertedIndex),
				kv(common.MetricTypeKey, metric.IP),
				kv(common.DropRatioBuildKey, "-0.1"),
			},
		},
		{
			name: "incorrect drop ratio build: greater than one",
			params: []*commonpb.KeyValuePair{
				kv(common.IndexTypeKey, integration.IndexSparseInvertedIndex),
				kv(common.MetricTypeKey, metric.IP),
				kv(common.DropRatioBuildKey, "1.1"),
			},
		},
	}

	for _, tc := range invalidCases {
		createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
			CollectionName: collectionName,
			FieldName:      integration.SparseFloatVecField,
			IndexName:      "_default",
			ExtraParams:    tc.params,
		})
		// The RPC itself must succeed; the rejection is reported in the status.
		s.NoError(err, tc.name)
		s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode(), tc.name)
	}
}
// TestSparse_invalid_search_request builds and loads a sparse inverted index
// over valid data, then checks that Search reports a non-success status when
// the first query vector is malformed: a negative column index, 4 extra
// trailing bytes, an empty row, or column indices that are not ascending.
func (s *SparseTestSuite) TestSparse_invalid_search_request() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	// valid insert
	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	// Confirm the collection is present in the response before using its entry.
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	s.Require().True(has)
	s.Require().NotEmpty(segmentIDs)
	ids := segmentIDs.GetData()
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	// build a valid index
	indexType := integration.IndexSparseInvertedIndex
	metricType := metric.IP
	indexParams := []*commonpb.KeyValuePair{
		{
			Key:   common.MetricTypeKey,
			Value: metricType,
		},
		{
			Key:   common.IndexTypeKey,
			Value: indexType,
		},
		{
			Key:   common.DropRatioBuildKey,
			Value: "0.1",
		},
	}
	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
	s.WaitForIndexBuilt(ctx, collectionName, integration.SparseFloatVecField)

	// load
	loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
	})
	s.NoError(err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
	s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
	s.WaitForLoad(ctx, collectionName)

	// search
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1

	params := integration.GetSearchParams(indexType, metricType)
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		integration.SparseFloatVecField, schemapb.DataType_SparseFloatVector, nil, metricType, params, nq, 0, topk, roundDecimal)

	// replaceQuery serializes vecs into a single "$0" placeholder and installs
	// it as the placeholder group of searchReq. A marshalling failure is a
	// test setup error, so it aborts the test immediately.
	replaceQuery := func(vecs *schemapb.SparseFloatArray) {
		bs, err := proto.Marshal(vecs)
		s.Require().NoError(err)
		plg := &commonpb.PlaceholderGroup{
			Placeholders: []*commonpb.PlaceholderValue{
				{
					Tag:    "$0",
					Type:   commonpb.PlaceholderType_SparseFloatVector,
					Values: [][]byte{bs},
				},
			},
		}
		plgBs, err := proto.Marshal(plg)
		s.Require().NoError(err)
		searchReq.PlaceholderGroup = plgBs
	}

	// expectRejected searches with vecs as the query and asserts that the RPC
	// succeeds but the returned status is not Success.
	expectRejected := func(vecs *schemapb.SparseFloatArray, reason string) {
		replaceQuery(vecs)
		searchResult, err := c.Proxy.Search(ctx, searchReq)
		s.NoError(err, reason)
		s.NotEqual(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode(), reason)
	}

	sparseVecs := integration.GenerateSparseFloatArray(nq)

	// negative column index: overwrite the first index of row 0 with the bit
	// pattern of -10, then restore the original value afterwards.
	oldIdx := typeutil.SparseFloatRowIndexAt(sparseVecs.Contents[0], 0)
	var newIdx int32 = -10
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], uint32(newIdx))
	expectRejected(sparseVecs, "negative column index")
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], oldIdx)

	// of each row, length of indices and data must equal
	sparseVecs.Contents[0] = append(sparseVecs.Contents[0], make([]byte, 4)...)
	expectRejected(sparseVecs, "row with 4 extra trailing bytes")

	// empty row is not allowed
	// NOTE(review): the commit description for "Allow empty sparse row" says
	// an empty query vector returns no neighbors, which may imply a success
	// status — confirm that a non-success status is the intended behavior.
	sparseVecs.Contents[0] = []byte{}
	expectRejected(sparseVecs, "empty row")

	// column index in the same row must be ordered: the row holds two entries
	// whose indices (20, then 10) are in descending order.
	sparseVecs.Contents[0] = make([]byte, 16)
	typeutil.SparseFloatRowSetAt(sparseVecs.Contents[0], 0, 20, 0.1)
	typeutil.SparseFloatRowSetAt(sparseVecs.Contents[0], 1, 10, 0.2)
	expectRejected(sparseVecs, "unsorted column index")
}
func TestSparse(t *testing.T) {
suite.Run(t, new(SparseTestSuite))
}