Add support for getting vectors by ids (#23450)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-04-23 09:00:32 +08:00 committed by GitHub
parent 897ed620e4
commit 092d743917
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 5997 additions and 705 deletions

View File

@ -318,6 +318,9 @@ rpm: install
@cp -r build/rpm/services ~/rpmbuild/BUILD/
@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec
mock-proxy:
mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter
mock-datanode:
mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter

View File

@ -85,6 +85,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
return true;
}
inline DatasetPtr
GenIdsDataset(const int64_t count, const int64_t* ids) {
auto ret_ds = std::make_shared<Dataset>();
ret_ds->SetRows(count);
ret_ds->SetDim(1);
ret_ds->SetIds(ids);
ret_ds->SetIsOwner(false);
return ret_ds;
}
inline DatasetPtr
GenResultDataset(const int64_t nq,
const int64_t topk,

View File

@ -230,6 +230,38 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
return result;
}
template <typename T>
const bool
VectorDiskAnnIndex<T>::HasRawData() const {
return index_.HasRawData(GetMetricType());
}
template <typename T>
const std::vector<uint8_t>
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset,
const Config& config) const {
auto res = index_.GetVectorByIds(*dataset, config);
if (!res.has_value()) {
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to get vector, " + MatchKnowhereError(res.error()));
}
auto index_type = GetIndexType();
auto tensor = res.value()->GetTensor();
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
memcpy(raw_data.data(), tensor, data_size);
return raw_data;
}
template <typename T>
void
VectorDiskAnnIndex<T>::CleanLocalData() {

View File

@ -17,6 +17,7 @@
#pragma once
#include <memory>
#include <vector>
#include "index/VectorIndex.h"
#include "storage/DiskFileManagerImpl.h"
@ -60,6 +61,13 @@ class VectorDiskAnnIndex : public VectorIndex {
const SearchInfo& search_info,
const BitsetView& bitset) override;
const bool
HasRawData() const override;
const std::vector<uint8_t>
GetVector(const DatasetPtr dataset,
const Config& config = {}) const override;
void
CleanLocalData() override;

View File

@ -19,6 +19,7 @@
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <boost/dynamic_bitset.hpp>
#include "knowhere/factory.h"
@ -50,6 +51,12 @@ class VectorIndex : public IndexBase {
const SearchInfo& search_info,
const BitsetView& bitset) = 0;
virtual const bool
HasRawData() const = 0;
virtual const std::vector<uint8_t>
GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0;
IndexType
GetIndexType() const {
return index_type_;

View File

@ -145,4 +145,34 @@ VectorMemIndex::Query(const DatasetPtr dataset,
return result;
}
const bool
VectorMemIndex::HasRawData() const {
return index_.HasRawData(GetMetricType());
}
const std::vector<uint8_t>
VectorMemIndex::GetVector(const DatasetPtr dataset,
const Config& config) const {
auto res = index_.GetVectorByIds(*dataset, config);
if (!res.has_value()) {
PanicCodeInfo(
ErrorCodeEnum::UnexpectedError,
"failed to get vector, " + MatchKnowhereError(res.error()));
}
auto index_type = GetIndexType();
auto tensor = res.value()->GetTensor();
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
memcpy(raw_data.data(), tensor, data_size);
return raw_data;
}
} // namespace milvus::index

View File

@ -51,6 +51,13 @@ class VectorMemIndex : public VectorIndex {
const SearchInfo& search_info,
const BitsetView& bitset) override;
const bool
HasRawData() const override;
const std::vector<uint8_t>
GetVector(const DatasetPtr dataset,
const Config& config = {}) const override;
protected:
Config config_;
knowhere::Index<knowhere::IndexNode> index_;

View File

@ -223,6 +223,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
return true;
}
bool
HasRawData(int64_t field_id) const override {
return true;
}
protected:
int64_t
num_chunk() const override;

View File

@ -88,6 +88,9 @@ class SegmentInterface {
virtual SegmentType
type() const = 0;
virtual bool
HasRawData(int64_t field_id) const = 0;
};
// internal API for DSL calculation

View File

@ -29,6 +29,7 @@
#include "query/ScalarIndex.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
#include "index/Utils.h"
namespace milvus::segcore {
@ -475,6 +476,35 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
}
}
std::unique_ptr<DataArray>
SegmentSealedImpl::get_vector(FieldId field_id,
const int64_t* ids,
int64_t count) const {
auto& filed_meta = schema_->operator[](field_id);
AssertInfo(filed_meta.is_vector(), "vector field is not vector type");
if (get_bit(index_ready_bitset_, field_id)) {
AssertInfo(vector_indexings_.is_ready(field_id),
"vector index is not ready");
auto field_indexing = vector_indexings_.get_field_indexing(field_id);
auto vec_index =
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
auto index_type = vec_index->GetIndexType();
auto metric_type = vec_index->GetMetricType();
auto has_raw_data = vec_index->HasRawData();
if (has_raw_data) {
auto ids_ds = GenIdsDataset(count, ids);
auto& vector = vec_index->GetVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
vector.data(), count, filed_meta);
}
}
return fill_with_empty(field_id, count);
}
void
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
if (SystemProperty::Instance().IsSystem(field_id)) {
@ -666,9 +696,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id,
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
}
// TODO: knowhere support reverse data from vector index
// Now, real data will be filled in data array using chunk manager
return fill_with_empty(field_id, count);
return get_vector(field_id, seg_offsets, count);
}
Assert(get_bit(field_data_ready_bitset_, field_id));
@ -783,6 +811,24 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
}
}
bool
SegmentSealedImpl::HasRawData(int64_t field_id) const {
std::shared_lock lck(mutex_);
auto fieldID = FieldId(field_id);
const auto& field_meta = schema_->operator[](fieldID);
if (datatype_is_vector(field_meta.get_data_type())) {
if (get_bit(index_ready_bitset_, fieldID)) {
AssertInfo(vector_indexings_.is_ready(fieldID),
"vector index is not ready");
auto field_indexing = vector_indexings_.get_field_indexing(fieldID);
auto vec_index = dynamic_cast<index::VectorIndex*>(
field_indexing->indexing_.get());
return vec_index->HasRawData();
}
}
return true;
}
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentSealedImpl::search_ids(const IdArray& id_array,
Timestamp timestamp) const {

View File

@ -61,6 +61,9 @@ class SegmentSealedImpl : public SegmentSealed {
return id_;
}
bool
HasRawData(int64_t field_id) const override;
public:
int64_t
GetMemoryUsageInBytes() const override;
@ -74,6 +77,9 @@ class SegmentSealedImpl : public SegmentSealed {
const Schema&
get_schema() const override;
std::unique_ptr<DataArray>
get_vector(FieldId field_id, const int64_t* ids, int64_t count) const;
public:
int64_t
num_chunk_index(FieldId field_id) const override;

View File

@ -164,6 +164,13 @@ GetRealCount(CSegmentInterface c_segment) {
return segment->get_real_count();
}
bool
HasRawData(CSegmentInterface c_segment, int64_t field_id) {
auto segment =
reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
return segment->HasRawData(field_id);
}
////////////////////////////// interfaces for growing segment //////////////////////////////
CStatus
Insert(CSegmentInterface c_segment,

View File

@ -67,6 +67,9 @@ GetDeletedCount(CSegmentInterface c_segment);
int64_t
GetRealCount(CSegmentInterface c_segment);
bool
HasRawData(CSegmentInterface c_segment, int64_t field_id);
////////////////////////////// interfaces for growing segment //////////////////////////////
CStatus
Insert(CSegmentInterface c_segment,

View File

@ -449,6 +449,91 @@ TEST_P(IndexTest, BuildAndQuery) {
vec_index->Query(xq_dataset, search_info, nullptr);
}
TEST_P(IndexTest, GetVector) {
milvus::index::CreateIndexInfo create_index_info;
create_index_info.index_type = index_type;
create_index_info.metric_type = metric_type;
create_index_info.field_type = vec_field_data_type;
index::IndexBasePtr index;
if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
auto file_manager =
std::make_shared<milvus::storage::DiskFileManagerImpl>(
field_data_meta, index_meta, storage_config_);
index = milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info, file_manager);
#endif
} else {
index = milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info, nullptr);
}
ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
milvus::index::IndexBasePtr new_index;
milvus::index::VectorIndex* vec_index = nullptr;
if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
// TODO ::diskann.query need load first, ugly
auto binary_set = index->Serialize(milvus::Config{});
index.reset();
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
auto file_manager =
std::make_shared<milvus::storage::DiskFileManagerImpl>(
field_data_meta, index_meta, storage_config_);
new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info, file_manager);
vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
std::vector<std::string> index_files;
for (auto& binary : binary_set.binary_map_) {
index_files.emplace_back(binary.first);
}
load_conf["index_files"] = index_files;
vec_index->Load(binary_set, load_conf);
EXPECT_EQ(vec_index->Count(), NB);
#endif
} else {
vec_index = dynamic_cast<milvus::index::VectorIndex*>(index.get());
}
EXPECT_EQ(vec_index->GetDim(), DIM);
EXPECT_EQ(vec_index->Count(), NB);
if (!vec_index->HasRawData()) {
return;
}
auto ids_ds = GenRandomIds(NB);
auto results = vec_index->GetVector(ids_ds);
EXPECT_TRUE(results.size() > 0);
if (!is_binary) {
std::vector<float> result_vectors(results.size() / (sizeof(float)));
memcpy(result_vectors.data(), results.data(), results.size());
EXPECT_TRUE(result_vectors.size() == xb_data.size());
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < DIM; ++j) {
EXPECT_TRUE(result_vectors[i * DIM + j] ==
xb_data[id * DIM + j]);
}
}
} else {
EXPECT_TRUE(results.size() == xb_bin_data.size());
const auto data_bytes = DIM / 8;
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < data_bytes; ++j) {
EXPECT_TRUE(results[i * data_bytes + j] ==
xb_bin_data[id * data_bytes + j]);
}
}
}
}
// #ifdef BUILD_DISK_ANN
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
// int64_t NB = 10000;

View File

@ -1067,3 +1067,52 @@ TEST(Sealed, RealCount) {
ASSERT_TRUE(status.ok());
ASSERT_EQ(0, segment->get_real_count());
}
TEST(Sealed, GetVector) {
auto dim = 16;
auto topK = 5;
auto N = ROW_COUNT;
auto metric_type = knowhere::metric::L2;
auto schema = std::make_shared<Schema>();
auto fakevec_id = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto counter_id = schema->AddDebugField("counter", DataType::INT64);
auto double_id = schema->AddDebugField("double", DataType::DOUBLE);
auto nothing_id = schema->AddDebugField("nothing", DataType::INT32);
auto str_id = schema->AddDebugField("str", DataType::VARCHAR);
schema->AddDebugField("int8", DataType::INT8);
schema->AddDebugField("int16", DataType::INT16);
schema->AddDebugField("float", DataType::FLOAT);
schema->set_primary_field_id(counter_id);
auto dataset = DataGen(schema, N);
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenVecIndexing(N, dim, fakevec.data());
auto segment_sealed = CreateSealedSegment(schema);
LoadIndexInfo vec_info;
vec_info.field_id = fakevec_id.get();
vec_info.index = std::move(indexing);
vec_info.index_params["metric_type"] = knowhere::metric::L2;
segment_sealed->LoadIndex(vec_info);
auto segment = dynamic_cast<SegmentSealedImpl*>(segment_sealed.get());
auto has = segment->HasRawData(vec_info.field_id);
EXPECT_TRUE(has);
auto ids_ds = GenRandomIds(N);
auto result = segment->get_vector(fakevec_id, ids_ds->GetIds(), N);
auto vector = result.get()->mutable_vectors()->float_vector().data();
EXPECT_TRUE(vector.size() == fakevec.size());
for (size_t i = 0; i < N; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < dim; ++j) {
EXPECT_TRUE(vector[i * dim + j] == fakevec[id * dim + j]);
}
}
}

View File

@ -599,4 +599,15 @@ GenPKs(const std::vector<int64_t>& pks) {
return GenPKs(pks.begin(), pks.end());
}
inline std::shared_ptr<knowhere::DataSet>
GenRandomIds(int rows, int64_t seed = 42) {
std::mt19937 g(seed);
auto* ids = new int64_t[rows];
for (int i = 0; i < rows; ++i) ids[i] = i;
std::shuffle(ids, ids + rows, g);
auto ids_ds = GenIdsDataset(rows, ids);
ids_ds->SetIsOwner(true);
return ids_ds;
}
} // namespace milvus::segcore

4086
internal/mocks/mock_proxy.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -2411,9 +2411,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
ReqID: paramtable.GetNodeID(),
},
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
shardMgr: node.shardMgr,
qc: node.queryCoord,
node: node,
}
travelTs := request.TravelTimestamp

View File

@ -9,6 +9,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/samber/lo"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
@ -492,7 +493,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
case *schemapb.IDs_IntId:
idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
case *schemapb.IDs_StrId:
idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]")
strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string {
return fmt.Sprintf("\"%s\"", str)
})
idsStr = strings.Trim(strings.Join(strs, ", "), "[]")
}
return fieldName + " in [ " + idsStr + " ]"

View File

@ -869,3 +869,28 @@ func Test_queryTask_createPlan(t *testing.T) {
assert.Error(t, err)
})
}
func TestQueryTask_IDs2Expr(t *testing.T) {
fieldName := "pk"
intIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4, 5},
},
},
}
stringIDs := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "c"},
},
},
}
idExpr := IDs2Expr(fieldName, intIDs)
expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
assert.Equal(t, expectIDExpr, idExpr)
strExpr := IDs2Expr(fieldName, stringIDs)
expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
assert.Equal(t, expectStrExpr, strExpr)
}

View File

@ -3,11 +3,13 @@ package proxy
import (
"context"
"fmt"
"math"
"regexp"
"strconv"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
@ -25,6 +27,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -34,6 +37,12 @@ import (
const (
SearchTaskName = "SearchTask"
SearchLevelKey = "level"
// requeryThreshold is the estimated threshold for the size of the search results.
// If the number of estimated search results exceeds this threshold,
// a second query request will be initiated to retrieve output fields data.
// In this case, the first search will not return any output field from QueryNodes.
requeryThreshold = 0.5 * 1024 * 1024
)
type searchTask struct {
@ -41,13 +50,14 @@ type searchTask struct {
*internalpb.SearchRequest
ctx context.Context
result *milvuspb.SearchResults
request *milvuspb.SearchRequest
qc types.QueryCoord
result *milvuspb.SearchResults
request *milvuspb.SearchRequest
tr *timerecord.TimeRecorder
collectionName string
channelNum int32
schema *schemapb.CollectionSchema
requery bool
offset int64
resultBuf chan *internalpb.SearchResults
@ -55,6 +65,9 @@ type searchTask struct {
searchShardPolicy pickShardPolicy
shardMgr *shardClientMgr
qc types.QueryCoord
node types.ProxyComponent
}
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
@ -164,11 +177,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string)
hitField := false
for _, field := range schema.GetFields() {
if field.Name == name {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
return nil, errors.New("search doesn't support vector field as output_fields")
}
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
hitField = true
break
}
@ -255,6 +264,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
t.SearchRequest.IgnoreGrowing = ignoreGrowing
// Manually update nq if not set.
nq, err := getNq(t.request)
if err != nil {
return err
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil {
@ -278,17 +305,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
if err != nil {
return err
}
if estimateSize >= requeryThreshold {
t.requery = true
plan.OutputFieldIds = nil
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
@ -319,17 +350,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.Dsl = t.request.Dsl
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
// Manually update nq if not set.
nq, err := getNq(t.request)
if err != nil {
return err
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
log.Ctx(ctx).Debug("search PreExecute done.",
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
@ -435,6 +455,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()
if t.requery {
err = t.Requery()
if err != nil {
return err
}
}
log.Ctx(ctx).Debug("Search post execute done",
zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()))
@ -480,6 +507,93 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
return nil
}
func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
// Currently, we get vectors by requery. Once we support getting vectors from search,
// searches with small result size could no longer need requery.
if len(vectorOutputFields) > 0 {
return math.MaxInt64, nil
}
// If no vector field as output, no need to requery.
return 0, nil
//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
// return lo.Contains(t.request.GetOutputFields(), field.GetName())
//})
//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
//if err != nil {
// return 0, err
//}
//return int64(sizePerRecord) * nq * topK, nil
}
func (t *searchTask) Requery() error {
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
if err != nil {
return err
}
ids := t.result.GetResults().GetIds()
expr := IDs2Expr(pkField.GetName(), ids)
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
CollectionName: t.request.GetCollectionName(),
Expr: expr,
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
TravelTimestamp: t.request.GetTravelTimestamp(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
QueryParams: t.request.GetSearchParams(),
}
queryResult, err := t.node.Query(t.ctx, queryReq)
if err != nil {
return err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return merr.Error(queryResult.GetStatus())
}
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
// We should reorganize query results to keep the order of original queried ids. For example:
// ===========================================
// 3 2 5 4 1 (query ids)
// ||
// || (query)
// \/
// 4 3 5 1 2 (result ids)
// v4 v3 v5 v1 v2 (result vectors)
// ||
// || (reorganize)
// \/
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetDataSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := offsets[id]; !ok {
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
}
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
}
return nil
}
func (t *searchTask) fillInEmptyResult(numQueries int64) {
t.result = &milvuspb.SearchResults{
Status: &commonpb.Status{

View File

@ -17,6 +17,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -268,7 +269,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
// contain vector field
task.request.OutputFields = []string{testFloatVecField}
assert.Error(t, task.PreExecute(ctx))
assert.NoError(t, task.PreExecute(ctx))
})
}
@ -1959,3 +1960,287 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
}
return &result
}
func TestSearchTask_Requery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const (
dim = 128
rows = 5
collection = "test-requery"
pkField = "pk"
vecField = "vec"
)
ids := make([]int64, rows)
for i := range ids {
ids[i] = int64(i)
}
t.Run("Test normal", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{{
Type: schemapb.DataType_Int64,
FieldName: pkField,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: ids,
},
},
},
},
},
newFloatVectorFieldData(vecField, rows, dim),
},
}, nil)
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
assert.NoError(t, err)
})
t.Run("Test no primary key", func(t *testing.T) {
schema := &schemapb.CollectionSchema{}
node := mocks.NewProxy(t)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test requery failed 1", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test requery failed 2", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock err 2",
},
}, nil)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test get pk filed data failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{},
}, nil)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test incomplete query result", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{{
Type: schemapb.DataType_Int64,
FieldName: pkField,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: ids[:len(ids)-1],
},
},
},
},
},
newFloatVectorFieldData(vecField, rows, dim),
},
}, nil)
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test postExecute with requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
requery: true,
schema: schema,
resultBuf: make(chan *internalpb.SearchResults, 10),
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
scores := make([]float32, rows)
for i := range scores {
scores[i] = float32(i)
}
partialResultData := &schemapb.SearchResultData{
Ids: resultIDs,
Scores: scores,
}
bytes, err := proto.Marshal(partialResultData)
assert.NoError(t, err)
qt.resultBuf <- &internalpb.SearchResults{
SlicedBlob: bytes,
}
err = qt.PostExecute(ctx)
t.Logf("err = %s", err)
assert.Error(t, err)
})
}

View File

@ -282,6 +282,13 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
return fieldInfo.IndexInfo != nil && fieldInfo.IndexInfo.EnableIndex
}
func (s *LocalSegment) HasRawData(fieldID int64) bool {
s.mut.RLock()
defer s.mut.RUnlock()
ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
return bool(ret)
}
func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
var result []*IndexedFieldInfo
s.fieldIndexes.Range(func(key int64, value *IndexedFieldInfo) bool {
@ -463,10 +470,18 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
)
for _, fieldData := range result.FieldsData {
// If the vector field doesn't have indexed. Vector data is in memory for
// brute force search. No need to download data from remote.
if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector ||
!s.ExistIndex(fieldData.FieldId) {
// If the field is not vector field, no need to download data from remote.
if !typeutil.IsVectorType(fieldData.GetType()) {
continue
}
// If the vector field doesn't have indexed, vector data is in memory
// for brute force search, no need to download data from remote.
if !s.ExistIndex(fieldData.FieldId) {
continue
}
// If the index has raw data, vector data could be obtained from index,
// no need to download data from remote.
if s.HasRawData(fieldData.FieldId) {
continue
}

View File

@ -117,6 +117,13 @@ func (suite *SegmentSuite) TestDelete() {
suite.Equal(rowNum, suite.growing.InsertCount())
}
func (suite *SegmentSuite) TestHasRawData() {
has := suite.growing.HasRawData(simpleFloatVecField.id)
suite.True(has)
has = suite.sealed.HasRawData(simpleFloatVecField.id)
suite.True(has)
}
func TestSegment(t *testing.T) {
suite.Run(t, new(SegmentSuite))
}

View File

@ -735,6 +735,30 @@ func GetSizeOfIDs(data *schemapb.IDs) int {
return result
}
func GetDataSize(fieldData *schemapb.FieldData) int {
switch fieldData.GetType() {
case schemapb.DataType_Bool:
return len(fieldData.GetScalars().GetBoolData().GetData())
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return len(fieldData.GetScalars().GetIntData().GetData())
case schemapb.DataType_Int64:
return len(fieldData.GetScalars().GetLongData().GetData())
case schemapb.DataType_Float:
return len(fieldData.GetScalars().GetFloatData().GetData())
case schemapb.DataType_Double:
return len(fieldData.GetScalars().GetDoubleData().GetData())
case schemapb.DataType_String:
return len(fieldData.GetScalars().GetStringData().GetData())
case schemapb.DataType_VarChar:
return len(fieldData.GetScalars().GetStringData().GetData())
case schemapb.DataType_FloatVector:
return len(fieldData.GetVectors().GetFloatVector().GetData())
case schemapb.DataType_BinaryVector:
return len(fieldData.GetVectors().GetBinaryVector())
}
return 0
}
func IsPrimaryFieldType(dataType schemapb.DataType) bool {
if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
return true
@ -756,6 +780,33 @@ func GetPK(data *schemapb.IDs, idx int64) interface{} {
return nil
}
func GetData(field *schemapb.FieldData, idx int) interface{} {
switch field.GetType() {
case schemapb.DataType_Bool:
return field.GetScalars().GetBoolData().GetData()[idx]
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return field.GetScalars().GetIntData().GetData()[idx]
case schemapb.DataType_Int64:
return field.GetScalars().GetLongData().GetData()[idx]
case schemapb.DataType_Float:
return field.GetScalars().GetFloatData().GetData()[idx]
case schemapb.DataType_Double:
return field.GetScalars().GetDoubleData().GetData()[idx]
case schemapb.DataType_String:
return field.GetScalars().GetStringData().GetData()[idx]
case schemapb.DataType_VarChar:
return field.GetScalars().GetStringData().GetData()[idx]
case schemapb.DataType_FloatVector:
dim := int(field.GetVectors().GetDim())
return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim]
case schemapb.DataType_BinaryVector:
dim := int(field.GetVectors().GetDim())
dataBytes := dim / 8
return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes]
}
return nil
}
func AppendPKs(pks *schemapb.IDs, pk interface{}) {
switch realPK := pk.(type) {
case int64:

View File

@ -470,6 +470,21 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
},
FieldId: fieldID,
}
case schemapb.DataType_String:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_String,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: fieldValue.([]string),
},
},
},
},
FieldId: fieldID,
}
case schemapb.DataType_VarChar:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
@ -990,3 +1005,94 @@ func TestCalcColumnSize(t *testing.T) {
assert.Equal(t, expected, size, field.GetName())
}
}
func TestGetDataAndGetDataSize(t *testing.T) {
const (
Dim = 8
fieldName = "filed-0"
fieldID = 0
)
BoolArray := []bool{true, false}
Int8Array := []int8{1, 2}
Int16Array := []int16{3, 4}
Int32Array := []int32{5, 6}
Int64Array := []int64{11, 22}
FloatArray := []float32{1.0, 2.0}
DoubleArray := []float64{11.0, 22.0}
VarCharArray := []string{"a", "b"}
StringArray := []string{"c", "d"}
BinaryVector := []byte{0x12, 0x34}
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1)
int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1)
int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1)
int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1)
int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1)
floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1)
doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1)
varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1)
stringData := genFieldData(fieldName, fieldID, schemapb.DataType_String, StringArray, 1)
binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim)
floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim)
invalidData := &schemapb.FieldData{
Type: schemapb.DataType_None,
}
t.Run("test GetDataSize", func(t *testing.T) {
boolDataRes := GetDataSize(boolData)
int8DataRes := GetDataSize(int8Data)
int16DataRes := GetDataSize(int16Data)
int32DataRes := GetDataSize(int32Data)
int64DataRes := GetDataSize(int64Data)
floatDataRes := GetDataSize(floatData)
doubleDataRes := GetDataSize(doubleData)
varCharDataRes := GetDataSize(varCharData)
stringDataRes := GetDataSize(stringData)
binVecDataRes := GetDataSize(binVecData)
floatVecDataRes := GetDataSize(floatVecData)
invalidDataRes := GetDataSize(invalidData)
assert.Equal(t, 2, boolDataRes)
assert.Equal(t, 2, int8DataRes)
assert.Equal(t, 2, int16DataRes)
assert.Equal(t, 2, int32DataRes)
assert.Equal(t, 2, int64DataRes)
assert.Equal(t, 2, floatDataRes)
assert.Equal(t, 2, doubleDataRes)
assert.Equal(t, 2, varCharDataRes)
assert.Equal(t, 2, stringDataRes)
assert.Equal(t, 2*Dim/8, binVecDataRes)
assert.Equal(t, 2*Dim, floatVecDataRes)
assert.Equal(t, 0, invalidDataRes)
})
t.Run("test GetData", func(t *testing.T) {
boolDataRes := GetData(boolData, 0)
int8DataRes := GetData(int8Data, 0)
int16DataRes := GetData(int16Data, 0)
int32DataRes := GetData(int32Data, 0)
int64DataRes := GetData(int64Data, 0)
floatDataRes := GetData(floatData, 0)
doubleDataRes := GetData(doubleData, 0)
varCharDataRes := GetData(varCharData, 0)
stringDataRes := GetData(stringData, 0)
binVecDataRes := GetData(binVecData, 0)
floatVecDataRes := GetData(floatVecData, 0)
invalidDataRes := GetData(invalidData, 0)
assert.Equal(t, BoolArray[0], boolDataRes)
assert.Equal(t, int32(Int8Array[0]), int8DataRes)
assert.Equal(t, int32(Int16Array[0]), int16DataRes)
assert.Equal(t, Int32Array[0], int32DataRes)
assert.Equal(t, Int64Array[0], int64DataRes)
assert.Equal(t, FloatArray[0], floatDataRes)
assert.Equal(t, DoubleArray[0], doubleDataRes)
assert.Equal(t, VarCharArray[0], varCharDataRes)
assert.Equal(t, StringArray[0], stringDataRes)
assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes)
assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes)
assert.Nil(t, invalidDataRes)
})
}

View File

@ -51,72 +51,24 @@ const (
// 5, load
// 6, search
func TestBulkInsert(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
defer func() {
err = c.Stop()
assert.NoError(t, err)
cancel()
}()
prefix := "TestBulkInsert"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "embeddings"
scalarField := "image_path"
floatVecField := floatVecField
dim := 128
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
Name: int64Field,
IsPrimaryKey: true,
Description: "pk",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
scalar := &schemapb.FieldSchema{
Name: scalarField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "65535",
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
scalar,
},
}
}
schema := constructCollectionSchema()
schema := constructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
@ -207,28 +159,7 @@ func TestBulkInsert(t *testing.T) {
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "HNSW",
},
{
Key: "M",
Value: "64",
},
{
Key: "efConstruction",
Value: "512",
},
},
ExtraParams: constructIndexParam(dim, IndexHNSW, distance.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
@ -246,30 +177,17 @@ func TestBulkInsert(t *testing.T) {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
waitingForLoad(ctx, c, collectionName)
// search
expr := fmt.Sprintf("%s > 0", "int64")
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
params := make(map[string]int)
params["nprobe"] = nprobe
params := getSearchParams(IndexHNSW, distance.L2)
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := c.proxy.Search(ctx, searchReq)

View File

@ -0,0 +1,366 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"fmt"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type TestGetVectorSuite struct {
suite.Suite
ctx context.Context
cancel context.CancelFunc
cluster *MiniCluster
// test params
nq int
topK int
indexType string
metricType string
pkType schemapb.DataType
vecType schemapb.DataType
}
func (suite *TestGetVectorSuite) SetupTest() {
suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Second*600)
var err error
suite.cluster, err = StartMiniCluster(suite.ctx)
suite.Require().NoError(err)
err = suite.cluster.Start()
suite.Require().NoError(err)
}
func (suite *TestGetVectorSuite) run() {
collection := fmt.Sprintf("TestGetVector_%d_%d_%s_%s_%s",
suite.nq, suite.topK, suite.indexType, suite.metricType, funcutil.GenRandomStr())
const (
NB = 10000
dim = 128
)
pkFieldName := "pkField"
vecFieldName := "vecField"
pk := &schemapb.FieldSchema{
FieldID: 100,
Name: pkFieldName,
IsPrimaryKey: true,
Description: "",
DataType: suite.pkType,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "max_length",
Value: "100",
},
},
IndexParams: nil,
AutoID: false,
}
fVec := &schemapb.FieldSchema{
FieldID: 101,
Name: vecFieldName,
IsPrimaryKey: false,
Description: "",
DataType: suite.vecType,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: fmt.Sprintf("%d", dim),
},
},
IndexParams: nil,
}
schema := constructSchema(collection, dim, false, pk, fVec)
marshaledSchema, err := proto.Marshal(schema)
suite.Require().NoError(err)
createCollectionStatus, err := suite.cluster.proxy.CreateCollection(suite.ctx, &milvuspb.CreateCollectionRequest{
CollectionName: collection,
Schema: marshaledSchema,
ShardsNum: 2,
})
suite.Require().NoError(err)
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
fieldsData := make([]*schemapb.FieldData, 0)
if suite.pkType == schemapb.DataType_Int64 {
fieldsData = append(fieldsData, newInt64FieldData(pkFieldName, NB))
} else {
fieldsData = append(fieldsData, newStringFieldData(pkFieldName, NB))
}
var vecFieldData *schemapb.FieldData
if suite.vecType == schemapb.DataType_FloatVector {
vecFieldData = newFloatVectorFieldData(vecFieldName, NB, dim)
} else {
vecFieldData = newBinaryVectorFieldData(vecFieldName, NB, dim)
}
fieldsData = append(fieldsData, vecFieldData)
hashKeys := generateHashKeys(NB)
_, err = suite.cluster.proxy.Insert(suite.ctx, &milvuspb.InsertRequest{
CollectionName: collection,
FieldsData: fieldsData,
HashKeys: hashKeys,
NumRows: uint32(NB),
})
suite.Require().NoError(err)
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
// flush
flushResp, err := suite.cluster.proxy.Flush(suite.ctx, &milvuspb.FlushRequest{
CollectionNames: []string{collection},
})
suite.Require().NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collection]
ids := segmentIDs.GetData()
suite.Require().NotEmpty(segmentIDs)
suite.Require().True(has)
segments, err := suite.cluster.metaWatcher.ShowSegments()
suite.Require().NoError(err)
suite.Require().NotEmpty(segments)
waitingForFlush(suite.ctx, suite.cluster, ids)
// create index
_, err = suite.cluster.proxy.CreateIndex(suite.ctx, &milvuspb.CreateIndexRequest{
CollectionName: collection,
FieldName: vecFieldName,
IndexName: "_default",
ExtraParams: constructIndexParam(dim, suite.indexType, suite.metricType),
})
suite.Require().NoError(err)
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
// load
_, err = suite.cluster.proxy.LoadCollection(suite.ctx, &milvuspb.LoadCollectionRequest{
CollectionName: collection,
})
suite.Require().NoError(err)
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
waitingForLoad(suite.ctx, suite.cluster, collection)
// search
nq := suite.nq
topk := suite.topK
outputFields := []string{vecFieldName}
params := getSearchParams(suite.indexType, suite.metricType)
searchReq := constructSearchRequest("", collection, "",
vecFieldName, suite.vecType, outputFields, suite.metricType, params, nq, dim, topk, -1)
searchResp, err := suite.cluster.proxy.Search(suite.ctx, searchReq)
suite.Require().NoError(err)
suite.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
result := searchResp.GetResults()
if suite.pkType == schemapb.DataType_Int64 {
suite.Require().Len(result.GetIds().GetIntId().GetData(), nq*topk)
} else {
suite.Require().Len(result.GetIds().GetStrId().GetData(), nq*topk)
}
suite.Require().Len(result.GetScores(), nq*topk)
suite.Require().GreaterOrEqual(len(result.GetFieldsData()), 1)
var vecFieldIndex = -1
for i, fieldData := range result.GetFieldsData() {
if typeutil.IsVectorType(fieldData.GetType()) {
vecFieldIndex = i
break
}
}
suite.Require().EqualValues(nq, result.GetNumQueries())
suite.Require().EqualValues(topk, result.GetTopK())
// check output vectors
if suite.vecType == schemapb.DataType_FloatVector {
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData(), nq*topk*dim)
rawData := vecFieldData.GetVectors().GetFloatVector().GetData()
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData()
if suite.pkType == schemapb.DataType_Int64 {
for i, id := range result.GetIds().GetIntId().GetData() {
expect := rawData[int(id)*dim : (int(id)+1)*dim]
actual := resData[i*dim : (i+1)*dim]
suite.Require().ElementsMatch(expect, actual)
}
} else {
for i, idStr := range result.GetIds().GetStrId().GetData() {
id, err := strconv.Atoi(idStr)
suite.Require().NoError(err)
expect := rawData[id*dim : (id+1)*dim]
actual := resData[i*dim : (i+1)*dim]
suite.Require().ElementsMatch(expect, actual)
}
}
} else {
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8)
rawData := vecFieldData.GetVectors().GetBinaryVector()
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector()
if suite.pkType == schemapb.DataType_Int64 {
for i, id := range result.GetIds().GetIntId().GetData() {
dataBytes := dim / 8
for j := 0; j < dataBytes; j++ {
expect := rawData[int(id)*dataBytes+j]
actual := resData[i*dataBytes+j]
suite.Require().Equal(expect, actual)
}
}
} else {
for i, idStr := range result.GetIds().GetStrId().GetData() {
dataBytes := dim / 8
id, err := strconv.Atoi(idStr)
suite.Require().NoError(err)
for j := 0; j < dataBytes; j++ {
expect := rawData[id*dataBytes+j]
actual := resData[i*dataBytes+j]
suite.Require().Equal(expect, actual)
}
}
}
}
status, err := suite.cluster.proxy.DropCollection(suite.ctx, &milvuspb.DropCollectionRequest{
CollectionName: collection,
})
suite.Require().NoError(err)
suite.Require().Equal(status.GetErrorCode(), commonpb.ErrorCode_Success)
}
func (suite *TestGetVectorSuite) TestGetVector_FLAT() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexFaissIDMap
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexFaissIvfFlat
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_IVF_PQ() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexFaissIvfPQ
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexFaissIvfSQ8
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_HNSW() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexHNSW
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_IP() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexHNSW
suite.metricType = distance.IP
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_StringPK() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexHNSW
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_VarChar
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_BinaryVector() {
suite.nq = 10
suite.topK = 10
suite.indexType = IndexFaissBinIvfFlat
suite.metricType = distance.JACCARD
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_BinaryVector
suite.run()
}
func (suite *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
suite.nq = 10000
suite.topK = 200
suite.indexType = IndexHNSW
suite.metricType = distance.L2
suite.pkType = schemapb.DataType_Int64
suite.vecType = schemapb.DataType_FloatVector
suite.run()
}
//func (suite *TestGetVectorSuite) TestGetVector_DISKANN() {
// suite.nq = 10
// suite.topK = 10
// suite.indexType = IndexDISKANN
// suite.metricType = distance.L2
// suite.pkType = schemapb.DataType_Int64
// suite.vecType = schemapb.DataType_FloatVector
// suite.run()
//}
func (suite *TestGetVectorSuite) TearDownTest() {
err := suite.cluster.Stop()
suite.Require().NoError(err)
suite.cancel()
}
func TestGetVector(t *testing.T) {
suite.Run(t, new(TestGetVectorSuite))
}

View File

@ -17,17 +17,11 @@
package integration
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand"
"strconv"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@ -42,59 +36,27 @@ import (
)
func TestHelloMilvus(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
defer cancel()
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
defer func() {
err = c.Stop()
assert.NoError(t, err)
cancel()
}()
prefix := "TestHelloMilvus"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
const (
dim = 128
dbName = ""
rowNum = 3000
)
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
collectionName := "TestHelloMilvus" + funcutil.GenRandomStr()
schema := constructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
@ -137,6 +99,7 @@ func TestHelloMilvus(t *testing.T) {
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
assert.True(t, has)
segments, err := c.metaWatcher.ShowSegments()
assert.NoError(t, err)
@ -144,52 +107,14 @@ func TestHelloMilvus(t *testing.T) {
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
//panic(errors.New("GetFlushState failed"))
return false
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
waitingForFlush(ctx, c, ids)
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
@ -207,30 +132,17 @@ func TestHelloMilvus(t *testing.T) {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
for {
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
waitingForLoad(ctx, c, collectionName)
// search
expr := fmt.Sprintf("%s > 0", "int64")
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
params := make(map[string]int)
params["nprobe"] = nprobe
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := c.proxy.Search(ctx, searchReq)
@ -242,155 +154,3 @@ func TestHelloMilvus(t *testing.T) {
log.Info("TestHelloMilvus succeed")
}
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
)
func constructSearchRequest(
dbName, collectionName string,
expr string,
floatVecField string,
metricType string,
params map[string]int,
nq, dim, topk, roundDecimal int,
) *milvuspb.SearchRequest {
b, err := json.Marshal(params)
if err != nil {
panic(err)
}
plg := constructPlaceholderGroup(nq, dim)
plgBs, err := proto.Marshal(plg)
if err != nil {
panic(err)
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: metricType,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
func constructPlaceholderGroup(
nq, dim int,
) *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
if err != nil {
panic(err)
}
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(numRows, dim),
},
},
},
},
}
}
func newInt64PrimaryKey(fieldName string, numRows int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: generateInt64Array(numRows),
},
},
},
},
}
}
func generateFloatVectors(numRows, dim int) []float32 {
total := numRows * dim
ret := make([]float32, 0, total)
for i := 0; i < total; i++ {
ret = append(ret, rand.Float32())
}
return ret
}
func generateInt64Array(numRows int) []int64 {
ret := make([]int64, 0, numRows)
for i := 0; i < numRows; i++ {
ret = append(ret, int64(rand.Int()))
}
return ret
}
func generateHashKeys(numRows int) []uint32 {
ret := make([]uint32, 0, numRows)
for i := 0; i < numRows; i++ {
ret = append(ret, rand.Uint32())
}
return ret
}

View File

@ -19,16 +19,13 @@ package integration
import (
"context"
"fmt"
"strconv"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
@ -39,59 +36,24 @@ import (
)
func TestRangeSearchIP(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
defer func() {
err = c.Stop()
assert.NoError(t, err)
cancel()
}()
prefix := "TestRangeSearchIP"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
schema := constructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
@ -133,6 +95,7 @@ func TestRangeSearchIP(t *testing.T) {
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
assert.True(t, has)
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
@ -142,52 +105,14 @@ func TestRangeSearchIP(t *testing.T) {
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
//panic(errors.New("GetFlushState failed"))
return false
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
waitingForFlush(ctx, c, ids)
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.IP,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
})
assert.NoError(t, err)
err = merr.Error(createIndexStatus)
@ -205,34 +130,21 @@ func TestRangeSearchIP(t *testing.T) {
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
for {
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
waitingForLoad(ctx, c, collectionName)
// search
expr := fmt.Sprintf("%s > 0", "int64")
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
radius := 10
filter := 20
params := make(map[string]int)
params["nprobe"] = nprobe
params := getSearchParams(IndexFaissIvfFlat, distance.IP)
// only pass in radius when range search
params["radius"] = radius
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.proxy.Search(ctx, searchReq)
@ -245,7 +157,7 @@ func TestRangeSearchIP(t *testing.T) {
// pass in radius and range_filter when range search
params["range_filter"] = filter
searchReq = constructSearchRequest("", collectionName, expr,
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.proxy.Search(ctx, searchReq)
@ -259,7 +171,7 @@ func TestRangeSearchIP(t *testing.T) {
params["radius"] = filter
params["range_filter"] = radius
searchReq = constructSearchRequest("", collectionName, expr,
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.proxy.Search(ctx, searchReq)
@ -277,59 +189,24 @@ func TestRangeSearchIP(t *testing.T) {
}
func TestRangeSearchL2(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
assert.NoError(t, err)
defer func() {
err = c.Stop()
assert.NoError(t, err)
cancel()
}()
prefix := "TestRangeSearchL2"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: true,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
schema := constructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
@ -371,6 +248,7 @@ func TestRangeSearchL2(t *testing.T) {
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
assert.True(t, has)
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
@ -380,52 +258,14 @@ func TestRangeSearchL2(t *testing.T) {
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
//panic(errors.New("GetFlushState failed"))
return false
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
waitingForFlush(ctx, c, ids)
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
})
assert.NoError(t, err)
err = merr.Error(createIndexStatus)
@ -443,34 +283,20 @@ func TestRangeSearchL2(t *testing.T) {
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
for {
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
waitingForLoad(ctx, c, collectionName)
// search
expr := fmt.Sprintf("%s > 0", "int64")
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
radius := 20
filter := 10
params := make(map[string]int)
params["nprobe"] = nprobe
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
// only pass in radius when range search
params["radius"] = radius
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.proxy.Search(ctx, searchReq)
@ -483,7 +309,7 @@ func TestRangeSearchL2(t *testing.T) {
// pass in radius and range_filter when range search
params["range_filter"] = filter
searchReq = constructSearchRequest("", collectionName, expr,
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.proxy.Search(ctx, searchReq)
@ -497,7 +323,7 @@ func TestRangeSearchL2(t *testing.T) {
params["radius"] = filter
params["range_filter"] = radius
searchReq = constructSearchRequest("", collectionName, expr,
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.proxy.Search(ctx, searchReq)

View File

@ -19,13 +19,10 @@ package integration
import (
"context"
"fmt"
"strconv"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
@ -38,59 +35,25 @@ import (
)
func TestUpsert(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
c, err := StartMiniCluster(ctx)
assert.NoError(t, err)
err = c.Start()
assert.NoError(t, err)
defer c.Stop()
defer func() {
err = c.Stop()
assert.NoError(t, err)
cancel()
}()
assert.NoError(t, err)
prefix := "TestUpsert"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
rowNum := 3000
constructCollectionSchema := func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: false,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema := constructCollectionSchema()
schema := constructSchema(collectionName, dim, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
@ -113,7 +76,7 @@ func TestUpsert(t *testing.T) {
assert.True(t, merr.Ok(showCollectionsResp.GetStatus()))
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
pkFieldData := newInt64PrimaryKey(int64Field, rowNum)
pkFieldData := newInt64FieldData(int64Field, rowNum)
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
upsertResult, err := c.proxy.Upsert(ctx, &milvuspb.UpsertRequest{
@ -133,6 +96,7 @@ func TestUpsert(t *testing.T) {
})
assert.NoError(t, err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
assert.True(t, has)
ids := segmentIDs.GetData()
assert.NotEmpty(t, segmentIDs)
@ -142,52 +106,14 @@ func TestUpsert(t *testing.T) {
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
})
if err != nil {
//panic(errors.New("GetFlushState failed"))
return false
}
return resp.GetFlushed()
}
for !flushed() {
// respect context deadline/cancel
select {
case <-ctx.Done():
panic(errors.New("deadline exceeded"))
default:
}
time.Sleep(500 * time.Millisecond)
}
}
waitingForFlush(ctx, c, ids)
// create index
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: floatVecField,
IndexName: "_default",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
},
{
Key: "index_type",
Value: "IVF_FLAT",
},
{
Key: "nlist",
Value: strconv.Itoa(10),
},
},
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
})
assert.NoError(t, err)
err = merr.Error(createIndexStatus)
@ -205,30 +131,16 @@ func TestUpsert(t *testing.T) {
if err != nil {
log.Warn("LoadCollection fail reason", zap.Error(err))
}
for {
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collectionName,
})
if err != nil {
panic("GetLoadingProgress fail")
}
if loadProgress.GetProgress() == 100 {
break
}
time.Sleep(500 * time.Millisecond)
}
waitingForLoad(ctx, c, collectionName)
// search
expr := fmt.Sprintf("%s > 0", "int64")
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := -1
nprobe := 10
params := make(map[string]int)
params["nprobe"] = nprobe
params := getSearchParams(IndexFaissIvfFlat, "")
searchReq := constructSearchRequest("", collectionName, expr,
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.proxy.Search(ctx, searchReq)

View File

@ -0,0 +1,108 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
)
const (
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
IndexHNSW = indexparamcheck.IndexHNSW
IndexDISKANN = indexparamcheck.IndexDISKANN
)
func constructIndexParam(dim int, indexType string, metricType string) []*commonpb.KeyValuePair {
params := []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: strconv.Itoa(dim),
},
{
Key: common.MetricTypeKey,
Value: metricType,
},
{
Key: common.IndexTypeKey,
Value: indexType,
},
}
switch indexType {
case IndexFaissIDMap, IndexFaissBinIDMap:
// no index param is required
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8:
params = append(params, &commonpb.KeyValuePair{
Key: "nlist",
Value: "100",
})
case IndexFaissIvfPQ:
params = append(params, &commonpb.KeyValuePair{
Key: "nlist",
Value: "100",
})
params = append(params, &commonpb.KeyValuePair{
Key: "m",
Value: "16",
})
params = append(params, &commonpb.KeyValuePair{
Key: "nbits",
Value: "8",
})
case IndexHNSW:
params = append(params, &commonpb.KeyValuePair{
Key: "M",
Value: "16",
})
params = append(params, &commonpb.KeyValuePair{
Key: "efConstruction",
Value: "200",
})
case IndexDISKANN:
default:
panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType))
}
return params
}
func getSearchParams(indexType string, metricType string) map[string]any {
params := make(map[string]any)
switch indexType {
case IndexFaissIDMap, IndexFaissBinIDMap:
params["metric_type"] = metricType
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ:
params["nprobe"] = 8
case IndexHNSW:
params["ef"] = 200
case IndexDISKANN:
params["search_list"] = 5
default:
panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType))
}
return params
}

View File

@ -0,0 +1,154 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
func waitingForFlush(ctx context.Context, cluster *MiniCluster, segIDs []int64) {
flushed := func() bool {
resp, err := cluster.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: segIDs,
})
if err != nil {
return false
}
return resp.GetFlushed()
}
for !flushed() {
select {
case <-ctx.Done():
panic("flush timeout")
default:
time.Sleep(500 * time.Millisecond)
}
}
}
func newInt64FieldData(fieldName string, numRows int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: generateInt64Array(numRows),
},
},
},
},
}
}
func newStringFieldData(fieldName string, numRows int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: generateStringArray(numRows),
},
},
},
},
}
}
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: generateFloatVectors(numRows, dim),
},
},
},
},
}
}
func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: generateBinaryVectors(numRows, dim),
},
},
},
}
}
func generateInt64Array(numRows int) []int64 {
ret := make([]int64, numRows)
for i := 0; i < numRows; i++ {
ret[i] = int64(i)
}
return ret
}
func generateStringArray(numRows int) []string {
ret := make([]string, numRows)
for i := 0; i < numRows; i++ {
ret[i] = fmt.Sprintf("%d", i)
}
return ret
}
func generateFloatVectors(numRows, dim int) []float32 {
total := numRows * dim
ret := make([]float32, 0, total)
for i := 0; i < total; i++ {
ret = append(ret, rand.Float32())
}
return ret
}
func generateBinaryVectors(numRows, dim int) []byte {
total := (numRows * dim) / 8
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
}
return ret
}
func generateHashKeys(numRows int) []uint32 {
ret := make([]uint32, 0, numRows)
for i := 0; i < numRows; i++ {
ret = append(ret, rand.Uint32())
}
return ret
}

View File

@ -0,0 +1,166 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"math/rand"
"strconv"
"time"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
)
const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
)
func waitingForLoad(ctx context.Context, cluster *MiniCluster, collection string) {
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
loadProgress, err := cluster.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: collection,
})
if err != nil {
panic("GetLoadingProgress fail")
}
return loadProgress
}
for getLoadingProgress().GetProgress() != 100 {
select {
case <-ctx.Done():
panic("load timeout")
default:
time.Sleep(500 * time.Millisecond)
}
}
}
func constructSearchRequest(
dbName, collectionName string,
expr string,
vecField string,
vectorType schemapb.DataType,
outputFields []string,
metricType string,
params map[string]any,
nq, dim int, topk, roundDecimal int,
) *milvuspb.SearchRequest {
b, err := json.Marshal(params)
if err != nil {
panic(err)
}
plg := constructPlaceholderGroup(nq, dim, vectorType)
plgBs, err := proto.Marshal(plg)
if err != nil {
panic(err)
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: metricType,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: vecField,
},
{
Key: common.TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
var placeholderType commonpb.PlaceholderType
switch vectorType {
case schemapb.DataType_FloatVector:
placeholderType = commonpb.PlaceholderType_FloatVector
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
if err != nil {
panic(err)
}
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
case schemapb.DataType_BinaryVector:
placeholderType = commonpb.PlaceholderType_BinaryVector
for i := 0; i < nq; i++ {
total := dim / 8
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
}
values = append(values, ret)
}
default:
panic("invalid vector data type")
}
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: placeholderType,
Values: values,
},
},
}
}

View File

@ -0,0 +1,80 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package integration
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/pkg/common"
)
const (
boolField = "boolField"
int8Field = "int8Field"
int16Field = "int16Field"
int32Field = "int32Field"
int64Field = "int64Field"
floatField = "floatField"
doubleField = "doubleField"
varCharField = "varCharField"
floatVecField = "floatVecField"
binVecField = "binVecField"
)
func constructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema {
// if fields are specified, construct it
if len(fields) > 0 {
return &schemapb.CollectionSchema{
Name: collection,
AutoID: autoID,
Fields: fields,
}
}
// if no field is specified, use default
pk := &schemapb.FieldSchema{
FieldID: 100,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: autoID,
}
fVec := &schemapb.FieldSchema{
FieldID: 101,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: fmt.Sprintf("%d", dim),
},
},
IndexParams: nil,
}
return &schemapb.CollectionSchema{
Name: collection,
AutoID: autoID,
Fields: []*schemapb.FieldSchema{pk, fVec},
}
}

View File

@ -842,11 +842,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp, output_fields=output_fields,
check_task=CheckTasks.err_res,
check_items={"err_code": 1,
"err_msg": "Search doesn't support "
"vector field as output_fields"})
default_search_exp, output_fields=output_fields)
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])