feat: add more operation detail info for better allocation (#30438)

issue: #30436

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG 2024-03-28 06:33:11 +08:00 committed by GitHub
parent fbff46a005
commit b1a1cca10b
61 changed files with 1891 additions and 2355 deletions


@ -140,6 +140,8 @@ issues:
- which can be annoying to use
# Binds to all network interfaces
- G102
# Use of unsafe calls should be audited
- G103
# Errors unhandled
- G104
# file/folder Permission
@ -164,4 +166,5 @@ issues:
max-same-issues: 0
service:
golangci-lint-version: 1.55.2 # use the fixed version to not introduce new linters unexpectedly
# use the fixed version to not introduce new linters unexpectedly
golangci-lint-version: 1.55.2

go.mod (17 changes)

@ -9,6 +9,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0
github.com/aliyun/credentials-go v1.2.7
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e
github.com/apache/arrow/go/v12 v12.0.1
github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7
github.com/bits-and-blooms/bloom/v3 v3.0.1
github.com/blang/semver/v4 v4.0.0
@ -17,6 +18,7 @@ require (
github.com/cockroachdb/errors v1.9.1
github.com/containerd/cgroups/v3 v3.0.3 // indirect
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/gofrs/flock v0.8.1
github.com/gogo/protobuf v1.3.2
github.com/golang/protobuf v1.5.3
@ -26,9 +28,11 @@ require (
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240317125658-67a0f065c1de
github.com/minio/minio-go/v7 v7.0.61
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0
github.com/prometheus/client_model v0.3.0
github.com/prometheus/common v0.42.0
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
github.com/sbinet/npyio v0.6.0
github.com/soheilhy/cmux v0.1.5
@ -36,6 +40,7 @@ require (
github.com/spf13/viper v1.8.1
github.com/stretchr/testify v1.8.4
github.com/tecbot/gorocksdb v0.0.0-20191217155057-f0fad39f321c
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.865
github.com/tidwall/gjson v1.14.4
github.com/tikv/client-go/v2 v2.0.4
go.etcd.io/etcd/api/v3 v3.5.5
@ -49,6 +54,7 @@ require (
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.16.0
golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691
golang.org/x/net v0.19.0
golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.5.0
golang.org/x/text v0.14.0
@ -56,18 +62,9 @@ require (
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
)
require github.com/apache/arrow/go/v12 v12.0.1
require github.com/milvus-io/milvus-storage/go v0.0.0-20231227072638-ebd0b8e56d70
require (
github.com/go-playground/validator/v10 v10.14.0
github.com/milvus-io/milvus/pkg v0.0.0-00010101000000-000000000000
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.865
golang.org/x/net v0.19.0
)
require github.com/milvus-io/milvus/pkg v0.0.0-00010101000000-000000000000
require (
cloud.google.com/go/compute v1.20.1 // indirect


@ -189,6 +189,7 @@ struct SearchResult {
public:
int64_t total_nq_;
int64_t unity_topK_;
int64_t total_data_cnt_;
void* segment_;
// first fill data during search, and then update data after reducing search results
@ -223,6 +224,7 @@ struct RetrieveResult {
RetrieveResult() = default;
public:
int64_t total_data_cnt_;
void* segment_;
std::vector<int64_t> result_offsets_;
std::vector<DataArray> field_data_;


@ -71,6 +71,7 @@ empty_search_result(int64_t num_queries, SearchInfo& search_info) {
SearchResult final_result;
final_result.total_nq_ = num_queries;
final_result.unity_topK_ = 0; // no result
final_result.total_data_cnt_ = 0;
return final_result;
}
@ -212,6 +213,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
timestamp_,
final_view,
search_result);
search_result.total_data_cnt_ = final_view.size();
if (search_result.vector_iterators_.has_value()) {
std::vector<GroupByValueType> group_by_values;
GroupBy(search_result.vector_iterators_.value(),
@ -239,6 +241,7 @@ wrap_num_entities(int64_t cnt) {
auto scalar = arr.mutable_scalars();
scalar->mutable_long_data()->mutable_data()->Add(cnt);
retrieve_result->field_data_ = {arr};
retrieve_result->total_data_cnt_ = 0;
return retrieve_result;
}
@ -249,6 +252,7 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
AssertInfo(segment, "Support SegmentSmallIndex Only");
RetrieveResult retrieve_result;
retrieve_result.total_data_cnt_ = 0;
auto active_count = segment->get_active_count(timestamp_);
@ -295,10 +299,12 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
if (node.is_count_) {
auto cnt = bitset_holder.size() - bitset_holder.count();
retrieve_result = *(wrap_num_entities(cnt));
retrieve_result.total_data_cnt_ = bitset_holder.size();
retrieve_result_opt_ = std::move(retrieve_result);
return;
}
retrieve_result.total_data_cnt_ = bitset_holder.size();
bool false_filtered_out = false;
if (get_cache_offset) {
segment->timestamp_filter(bitset_holder, cache_offsets, timestamp_);


@ -348,12 +348,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
int64_t result_count = 0;
int64_t all_search_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() ==
search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count += search_result->topk_per_nq_prefix_sum_[nq_end] -
search_result->topk_per_nq_prefix_sum_[nq_begin];
all_search_count += search_result->total_data_cnt_;
}
auto search_result_data =
@ -363,6 +365,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
search_result_data->set_all_search_count(all_search_count);
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);


@ -100,6 +100,7 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
fmt::format("query results exceed the limit size ", limit_size));
}
results->set_all_retrieve_count(retrieve_results.total_data_cnt_);
if (plan->plan_node_->is_count_) {
AssertInfo(retrieve_results.field_data_.size() == 1,
"count result should only have one column");


@ -205,7 +205,7 @@ func (ib *indexBuilder) run() {
}
}
func getBinLogIds(segment *SegmentInfo, fieldID int64) []int64 {
func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 {
binlogIDs := make([]int64, 0)
for _, fieldBinLog := range segment.GetBinlogs() {
if fieldBinLog.GetFieldID() == fieldID {
@ -299,7 +299,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool {
FieldID: partitionKeyField.FieldID,
FieldName: partitionKeyField.Name,
FieldType: int32(partitionKeyField.DataType),
DataIds: getBinLogIds(segment, partitionKeyField.FieldID),
DataIds: getBinLogIDs(segment, partitionKeyField.FieldID),
})
}
}
@ -333,7 +333,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool {
}
fieldID := ib.meta.indexMeta.GetFieldIDByIndexID(meta.CollectionID, meta.IndexID)
binlogIDs := getBinLogIds(segment, fieldID)
binlogIDs := getBinLogIDs(segment, fieldID)
if isDiskANNIndex(GetIndexType(indexParams)) {
var err error
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)


@ -139,6 +139,31 @@ func (s *Server) getSystemInfoMetrics(
return resp, nil
}
func (s *Server) getCollectionStorageMetrics(ctx context.Context) (*milvuspb.GetMetricsResponse, error) {
coordTopology := metricsinfo.DataCoordTopology{
Cluster: metricsinfo.DataClusterTopology{
Self: s.getDataCoordMetrics(ctx),
},
Connections: metricsinfo.ConnTopology{
Name: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()),
ConnectedComponents: []metricsinfo.ConnectionInfo{},
},
}
resp := &milvuspb.GetMetricsResponse{
Status: merr.Success(),
ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()),
}
var err error
resp.Response, err = metricsinfo.MarshalTopology(coordTopology)
if err != nil {
resp.Status = merr.Status(err)
return resp, nil
}
return resp, nil
}
// getDataCoordMetrics composes datacoord infos
func (s *Server) getDataCoordMetrics(ctx context.Context) metricsinfo.DataCoordInfos {
ret := metricsinfo.DataCoordInfos{


@ -236,8 +236,8 @@ func TestLastExpireReset(t *testing.T) {
assert.Equal(t, expire1, segment1.GetLastExpireTime())
assert.Equal(t, expire2, segment2.GetLastExpireTime())
assert.True(t, segment3.GetLastExpireTime() > expire3)
flushableSegIds, _ := newSegmentManager.GetFlushableSegments(context.Background(), channelName, expire3)
assert.ElementsMatch(t, []UniqueID{segmentID1, segmentID2}, flushableSegIds) // segment1 and segment2 can be flushed
flushableSegIDs, _ := newSegmentManager.GetFlushableSegments(context.Background(), channelName, expire3)
assert.ElementsMatch(t, []UniqueID{segmentID1, segmentID2}, flushableSegIDs) // segment1 and segment2 can be flushed
newAlloc, err := newSegmentManager.AllocSegment(context.Background(), collID, 0, channelName, 2000)
assert.Nil(t, err)
assert.Equal(t, segmentID3, newAlloc[0].SegmentID) // segment3 still can be used to allocate


@ -94,7 +94,7 @@ type rootCoordCreatorFunc func(ctx context.Context) (types.RootCoordClient, erro
// makes sure Server implements `DataCoord`
var _ types.DataCoord = (*Server)(nil)
var Params *paramtable.ComponentParam = paramtable.Get()
var Params = paramtable.Get()
// Server implements `types.DataCoord`
// handles Data Coordinator related jobs
@ -161,6 +161,11 @@ type Server struct {
broker broker.Broker
}
type CollectionNameInfo struct {
CollectionName string
DBName string
}
// ServerHelper datacoord server injection helper
type ServerHelper struct {
eventAfterHandleDataNodeTt func()


@ -34,6 +34,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
@ -3836,3 +3837,67 @@ func TestUpdateAutoBalanceConfigLoop(t *testing.T) {
wg.Wait()
})
}
func TestGetCollectionStorage(t *testing.T) {
paramtable.Init()
mockSession := sessionutil.NewMockSession(t)
mockSession.EXPECT().GetAddress().Return("localhost:8888")
size := atomic.NewInt64(100)
s := &Server{
session: mockSession,
meta: &meta{
segments: &SegmentsInfo{
segments: map[UniqueID]*SegmentInfo{
1: {
SegmentInfo: &datapb.SegmentInfo{
ID: 1,
State: commonpb.SegmentState_Growing,
CollectionID: 10001,
PartitionID: 10000,
NumOfRows: 10,
},
size: *size,
},
2: {
SegmentInfo: &datapb.SegmentInfo{
ID: 2,
State: commonpb.SegmentState_Dropped,
CollectionID: 10001,
PartitionID: 10000,
NumOfRows: 10,
},
size: *size,
},
3: {
SegmentInfo: &datapb.SegmentInfo{
ID: 3,
State: commonpb.SegmentState_Flushed,
CollectionID: 10002,
PartitionID: 9999,
NumOfRows: 10,
},
size: *size,
},
},
},
},
}
s.stateCode.Store(commonpb.StateCode_Healthy)
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.CollectionStorageMetrics)
assert.NoError(t, err)
resp, err := s.GetMetrics(context.TODO(), req)
assert.NoError(t, err)
var coordTopology metricsinfo.DataCoordTopology
err = metricsinfo.UnmarshalTopology(resp.Response, &coordTopology)
assert.NoError(t, err)
m := coordTopology.Cluster.Self.QuotaMetrics
assert.NotNil(t, m)
assert.Equal(t, int64(200), m.TotalBinlogSize)
assert.Len(t, m.CollectionBinlogSize, 2)
assert.Equal(t, int64(100), m.CollectionBinlogSize[10001])
assert.Equal(t, int64(100), m.CollectionBinlogSize[10002])
}


@ -1029,6 +1029,23 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
zap.Any("metrics", metrics), // TODO(dragondriver): necessary? may be very large
zap.Error(err))
return metrics, nil
} else if metricType == metricsinfo.CollectionStorageMetrics {
metrics, err := s.getCollectionStorageMetrics(ctx)
if err != nil {
log.Warn("DataCoord GetMetrics CollectionStorage failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.Error(err))
return &milvuspb.GetMetricsResponse{
Status: merr.Status(err),
}, nil
}
log.RatedDebug(60, "DataCoord.GetMetrics CollectionStorage",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.String("req", req.Request),
zap.String("metricType", metricType),
zap.Any("metrics", metrics),
zap.Error(err))
return metrics, nil
}
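For reference, a minimal caller-side sketch of the new metric type (the fetchCollectionStorage helper and its wiring are hypothetical; the metricsinfo and merr calls are the same ones this change uses):

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
)

// fetchCollectionStorage requests CollectionStorageMetrics from DataCoord and
// unmarshals the topology payload produced by the handler above.
func fetchCollectionStorage(ctx context.Context, dc types.DataCoordClient) (*metricsinfo.DataCoordQuotaMetrics, error) {
	req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.CollectionStorageMetrics)
	if err != nil {
		return nil, err
	}
	rsp, err := dc.GetMetrics(ctx, req)
	if err = merr.CheckRPCCall(rsp, err); err != nil {
		return nil, err
	}
	topology := &metricsinfo.DataCoordTopology{}
	if err := metricsinfo.UnmarshalTopology(rsp.GetResponse(), topology); err != nil {
		return nil, err
	}
	return topology.Cluster.Self.QuotaMetrics, nil
}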


@ -232,7 +232,6 @@ func (c *SessionManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegments
}
return nil
})
if err != nil {
log.Warn("failed to sync segments after retry", zap.Error(err))
return err


@ -108,20 +108,20 @@ func getVchanInfo(info *testInfo) *datapb.VchannelInfo {
ufs = []*datapb.SegmentInfo{}
}
var ufsIds []int64
var fsIds []int64
var ufsIDs []int64
var fsIDs []int64
for _, segmentInfo := range ufs {
ufsIds = append(ufsIds, segmentInfo.ID)
ufsIDs = append(ufsIDs, segmentInfo.ID)
}
for _, segmentInfo := range fs {
fsIds = append(fsIds, segmentInfo.ID)
fsIDs = append(fsIDs, segmentInfo.ID)
}
vi := &datapb.VchannelInfo{
CollectionID: info.collID,
ChannelName: info.chanName,
SeekPosition: &msgpb.MsgPosition{},
UnflushedSegmentIds: ufsIds,
FlushedSegmentIds: fsIds,
UnflushedSegmentIds: ufsIDs,
FlushedSegmentIds: fsIDs,
}
return vi
}
@ -465,13 +465,13 @@ func (s *DataSyncServiceSuite) TestStartStop() {
NumOfRows: 0,
DmlPosition: &msgpb.MsgPosition{},
}}
var ufsIds []int64
var fsIds []int64
var ufsIDs []int64
var fsIDs []int64
for _, segmentInfo := range ufs {
ufsIds = append(ufsIds, segmentInfo.ID)
ufsIDs = append(ufsIDs, segmentInfo.ID)
}
for _, segmentInfo := range fs {
fsIds = append(fsIds, segmentInfo.ID)
fsIDs = append(fsIDs, segmentInfo.ID)
}
watchInfo := &datapb.ChannelWatchInfo{
@ -479,8 +479,8 @@ func (s *DataSyncServiceSuite) TestStartStop() {
Vchan: &datapb.VchannelInfo{
CollectionID: collMeta.ID,
ChannelName: insertChannelName,
UnflushedSegmentIds: ufsIds,
FlushedSegmentIds: fsIds,
UnflushedSegmentIds: ufsIDs,
FlushedSegmentIds: fsIDs,
},
}


@ -691,7 +691,7 @@ func TestInsert(t *testing.T) {
mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil)
mp5.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(schemapb.DataType_Int64),
IDs: genIDs(schemapb.DataType_Int64),
InsertCnt: 3,
}, nil).Once()
testCases = append(testCases, testCase{
@ -705,7 +705,7 @@ func TestInsert(t *testing.T) {
mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil)
mp6.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(schemapb.DataType_VarChar),
IDs: genIDs(schemapb.DataType_VarChar),
InsertCnt: 3,
}, nil).Once()
testCases = append(testCases, testCase{
@ -776,7 +776,7 @@ func TestInsertForDataType(t *testing.T) {
}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(schemapb.DataType_Int64),
IDs: genIDs(schemapb.DataType_Int64),
InsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -844,7 +844,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
InsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -875,7 +875,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
UpsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -906,7 +906,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
InsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -938,7 +938,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
UpsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -971,7 +971,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
InsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -1002,7 +1002,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
UpsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -1033,7 +1033,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
InsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -1065,7 +1065,7 @@ func TestReturnInt64(t *testing.T) {
}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(dataType),
IDs: genIDs(dataType),
UpsertCnt: 3,
}, nil).Once()
testEngine := initHTTPServer(mp, true)
@ -1132,7 +1132,7 @@ func TestUpsert(t *testing.T) {
mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil)
mp5.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(schemapb.DataType_Int64),
IDs: genIDs(schemapb.DataType_Int64),
UpsertCnt: 3,
}, nil).Once()
testCases = append(testCases, testCase{
@ -1146,7 +1146,7 @@ func TestUpsert(t *testing.T) {
mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil)
mp6.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
Status: &StatusSuccess,
IDs: genIds(schemapb.DataType_VarChar),
IDs: genIDs(schemapb.DataType_VarChar),
UpsertCnt: 3,
}, nil).Once()
testCases = append(testCases, testCase{
@ -1198,8 +1198,8 @@ func TestUpsert(t *testing.T) {
})
}
func genIds(dataType schemapb.DataType) *schemapb.IDs {
return generateIds(dataType, 3)
func genIDs(dataType schemapb.DataType) *schemapb.IDs {
return generateIDs(dataType, 3)
}
func TestSearch(t *testing.T) {


@ -37,7 +37,7 @@ func generatePrimaryField(datatype schemapb.DataType) schemapb.FieldSchema {
}
}
func generateIds(dataType schemapb.DataType, num int) *schemapb.IDs {
func generateIDs(dataType schemapb.DataType, num int) *schemapb.IDs {
var intArray []int64
if num == 0 {
intArray = []int64{}
@ -684,7 +684,7 @@ func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, c
func TestBuildQueryResp(t *testing.T) {
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
assert.Equal(t, nil, err)
exceptRows := generateSearchResult(schemapb.DataType_Int64)
assert.Equal(t, true, compareRows(rows, exceptRows, compareRow))
@ -1298,7 +1298,7 @@ func TestBuildQueryResps(t *testing.T) {
outputFields := []string{"XXX", "YYY"}
outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}}
for _, theOutputFields := range outputFieldsList {
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
assert.Equal(t, nil, err)
exceptRows := newSearchResult(generateSearchResult(schemapb.DataType_Int64))
assert.Equal(t, true, compareRows(rows, exceptRows, compareRow))
@ -1312,29 +1312,29 @@ func TestBuildQueryResps(t *testing.T) {
schemapb.DataType_JSON, schemapb.DataType_Array,
}
for _, dateType := range dataTypes {
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
assert.Equal(t, nil, err)
}
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
assert.Equal(t, "the type(1000) of field(wrong-field-type) is not supported, use other sdk please", err.Error())
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
assert.Equal(t, 3, len(res))
assert.Equal(t, nil, err)
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), DefaultScores, false)
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false)
assert.Equal(t, 3, len(res))
assert.Equal(t, nil, err)
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_VarChar, 3), DefaultScores, true)
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_VarChar, 3), DefaultScores, true)
assert.Equal(t, 3, len(res))
assert.Equal(t, nil, err)
_, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), DefaultScores, false)
_, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false)
assert.Equal(t, nil, err)
// len(rows) != len(scores), didn't show distance
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true)
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true)
assert.Equal(t, nil, err)
}


@ -125,16 +125,16 @@ func (suite *CatalogTestSuite) TestPartition() {
}
func (suite *CatalogTestSuite) TestReleaseManyPartitions() {
partitionIds := make([]int64, 0)
partitionIDs := make([]int64, 0)
for i := 1; i <= 150; i++ {
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
CollectionID: 1,
PartitionID: int64(i),
})
partitionIds = append(partitionIds, int64(i))
partitionIDs = append(partitionIDs, int64(i))
}
err := suite.catalog.ReleasePartition(1, partitionIds...)
err := suite.catalog.ReleasePartition(1, partitionIDs...)
suite.NoError(err)
partitions, err := suite.catalog.GetPartitions()
suite.NoError(err)


@ -16,6 +16,6 @@ type Segment struct {
CreatedByCompaction bool
SegmentState commonpb.SegmentState
// IndexInfos []*SegmentIndex
ReplicaIds []int64
NodeIds []int64
ReplicaIDs []int64
NodeIDs []int64
}


@ -134,6 +134,7 @@ message SearchResults {
// search request cost
CostAggregation costAggregation = 13;
map<string, uint64> channels_mvcc = 14;
int64 all_search_count = 15;
}
message CostAggregation {
@ -175,6 +176,7 @@ message RetrieveResults {
// query request cost
CostAggregation costAggregation = 13;
int64 all_retrieve_count = 14;
}
message LoadIndex {


@ -9,6 +9,7 @@ message RetrieveResults {
schema.IDs ids = 1;
repeated int64 offset = 2;
repeated schema.FieldData fields_data = 3;
int64 all_retrieve_count = 4;
}
message LoadFieldMeta {
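The two counters added above record how many rows a request actually scanned, so the proxy can report "related" row counts to the hook extension. A small sketch of the aggregation side, assuming the standard protoc-generated Go getter for the new field:

package example

import "github.com/milvus-io/milvus/internal/proto/internalpb"

// totalSearchCount sums the new all_search_count field across partial results,
// mirroring how the proxy reduce path accumulates AllSearchCount.
func totalSearchCount(partials []*internalpb.SearchResults) int64 {
	var total int64
	for _, r := range partials {
		total += r.GetAllSearchCount() // generated getter for field 15
	}
	return total
}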


@ -10,6 +10,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -140,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
user, _ := parseMD(rawToken)
assert.Equal(t, "mockUser", user)
}
hoo = defaultHook{}
hoo = hookutil.DefaultHook{}
}


@ -177,7 +177,6 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
var err error
stream, err = factory.NewMsgStream(context.Background())
if err != nil {
return nil, err
}


@ -4,6 +4,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type cntReducer struct {
@ -20,6 +21,7 @@ func (r *cntReducer) Reduce(results []*internalpb.RetrieveResults) (*milvuspb.Qu
cnt += c
}
res := funcutil.WrapCntToQueryResults(cnt)
res.Status = merr.Success()
res.CollectionName = r.collectionName
return res, nil
}


@ -2,93 +2,24 @@ package proxy
import (
"context"
"fmt"
"plugin"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
type defaultHook struct{}
func (d defaultHook) VerifyAPIKey(key string) (string, error) {
return "", errors.New("default hook, can't verify api key")
}
func (d defaultHook) Init(params map[string]string) error {
return nil
}
func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
return false, nil, nil
}
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
return ctx, nil
}
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
return nil
}
func (d defaultHook) Release() {}
var hoo hook.Hook
func initHook() error {
path := Params.ProxyCfg.SoPath.GetValue()
if path == "" {
hoo = defaultHook{}
return nil
}
logger.Debug("start to load plugin", zap.String("path", path))
p, err := plugin.Open(path)
if err != nil {
return fmt.Errorf("fail to open the plugin, error: %s", err.Error())
}
logger.Debug("plugin open")
h, err := p.Lookup("MilvusHook")
if err != nil {
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
}
var ok bool
hoo, ok = h.(hook.Hook)
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event))
go func() {
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = hoo.Init(soConfig); err != nil {
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
}
}()
})
return nil
}
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
if hookError := initHook(); hookError != nil {
logger.Error("hook error", zap.String("path", Params.ProxyCfg.SoPath.GetValue()), zap.Error(hookError))
hoo = defaultHook{}
}
hookutil.InitOnceHook()
hoo = hookutil.Hoo
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
fullMethod = info.FullMethod
@ -145,24 +76,13 @@ func getCurrentUser(ctx context.Context) string {
return username
}
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
defaultHook
mockErr error
apiUser string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.apiUser, m.mockErr
}
func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
hoo = defaultHook{}
hoo = &hookutil.DefaultHook{}
return
}
hoo = MockAPIHook{
mockErr: mockErr,
apiUser: apiUser,
hoo = &hookutil.MockAPIHook{
MockErr: mockErr,
User: apiUser,
}
}
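For plugin authors, a hedged sketch of a custom hook after this refactor (apiKeyHook and its key map are illustrative, not part of this change): embed hookutil.DefaultHook to inherit the no-op methods and override only what is needed, the same pattern the mocks above use.

package example

import (
	"errors"

	"github.com/milvus-io/milvus/internal/util/hookutil"
)

// apiKeyHook overrides only VerifyAPIKey; every other hook method is promoted
// from the embedded hookutil.DefaultHook no-op implementation.
type apiKeyHook struct {
	hookutil.DefaultHook
	keys map[string]string // api key -> username, illustrative data only
}

func (h *apiKeyHook) VerifyAPIKey(key string) (string, error) {
	if user, ok := h.keys[key]; ok {
		return user, nil
	}
	return "", errors.New("unknown api key")
}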


@ -8,22 +8,11 @@ import (
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/internal/util/hookutil"
)
func TestInitHook(t *testing.T) {
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook()
assert.IsType(t, defaultHook{}, hoo)
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook()
assert.Error(t, err)
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
}
type mockHook struct {
defaultHook
hookutil.DefaultHook
mockRes interface{}
mockErr error
}
@ -39,7 +28,7 @@ type req struct {
type BeforeMockCtxKey int
type beforeMock struct {
defaultHook
hookutil.DefaultHook
method string
ctxKey BeforeMockCtxKey
ctxValue string
@ -60,7 +49,7 @@ type resp struct {
}
type afterMock struct {
defaultHook
hookutil.DefaultHook
method string
err error
}
@ -129,7 +118,7 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)
hoo = defaultHook{}
hoo = &hookutil.DefaultHook{}
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,
@ -139,18 +128,6 @@ func TestHookInterceptor(t *testing.T) {
assert.NoError(t, err)
}
func TestDefaultHook(t *testing.T) {
d := defaultHook{}
assert.NoError(t, d.Init(nil))
{
_, err := d.VerifyAPIKey("key")
assert.Error(t, err)
}
assert.NotPanics(t, func() {
d.Release()
})
}
func TestUpdateProxyFunctionCallMetric(t *testing.T) {
assert.NotPanics(t, func() {
updateProxyFunctionCallMetric("/milvus.proto.milvus.MilvusService/Flush")


@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proxy/connection"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/importutilv2"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -2394,6 +2395,15 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
successCnt := it.result.InsertCnt - int64(len(it.result.ErrIndex))
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeInsert,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.DataSizeKey: proto.Size(request),
hookutil.SuccessCntKey: successCnt,
hookutil.FailCntKey: len(it.result.ErrIndex),
})
SetReportValue(it.result.GetStatus(), v)
metrics.ProxyInsertVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
@ -2469,6 +2479,15 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
successCnt := dr.result.GetDeleteCnt()
metrics.ProxyDeleteVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeDelete,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.SuccessCntKey: successCnt,
hookutil.RelatedCntKey: dr.allQueryCnt.Load(),
})
SetReportValue(dr.result.GetStatus(), v)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
@ -2584,6 +2603,16 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
// UpsertCnt always equals the number of entities in the request
it.result.UpsertCnt = int64(request.NumRows)
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.DataSizeKey: proto.Size(it.req),
hookutil.SuccessCntKey: it.result.UpsertCnt,
hookutil.FailCntKey: len(it.result.ErrIndex),
})
SetReportValue(it.result.GetStatus(), v)
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.InsertMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
@ -2759,6 +2788,15 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
if qt.result != nil {
sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeSearch,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.DataSizeKey: sentSize,
hookutil.RelatedCntKey: qt.result.GetResults().GetAllSearchCount(),
hookutil.DimensionKey: qt.dimension,
})
SetReportValue(qt.result.GetStatus(), v)
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
}
@ -2902,6 +2940,15 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
if qt.result != nil {
sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.DataSizeKey: sentSize,
hookutil.RelatedCntKey: qt.result.GetResults().GetAllSearchCount(),
hookutil.DimensionKey: qt.dimension,
})
SetReportValue(qt.result.GetStatus(), v)
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
}
@ -3182,7 +3229,19 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
qc: node.queryCoord,
lb: node.lbPolicy,
}
return node.query(ctx, qt)
res, err := node.query(ctx, qt)
if err == nil && merr.Ok(res.GetStatus()) {
v := Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeQuery,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
hookutil.DataSizeKey: proto.Size(res),
hookutil.RelatedCntKey: qt.allQueryCnt,
hookutil.DimensionKey: qt.dimension,
})
SetReportValue(res.Status, v)
}
return res, err
}
// CreateAlias create alias for collection, then you can search the collection with alias.


@ -59,6 +59,8 @@ type Cache interface {
GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error)
// GetCollectionInfo gets a collection's information, such as its schema, by name or collection id.
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionBasicInfo, error)
// GetCollectionNamesByID gets the collection names and database names for the given collection ids
GetCollectionNamesByID(ctx context.Context, collectionID []UniqueID) ([]string, []string, error)
// GetPartitionID get partition's identifier of specific collection.
GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
// GetPartitions get all partitions' id of specific collection.
@ -242,11 +244,12 @@ type MetaCache struct {
rootCoord types.RootCoordClient
queryCoord types.QueryCoordClient
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
dbInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
mu sync.RWMutex
credMut sync.RWMutex
leaderMut sync.RWMutex
@ -288,6 +291,7 @@ func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordCl
queryCoord: queryCoord,
collInfo: map[string]map[string]*collectionInfo{},
collLeader: map[string]map[string]*shardLeaders{},
dbInfo: map[string]map[typeutil.UniqueID]string{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
privilegeInfos: map[string]struct{}{},
@ -471,6 +475,90 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
return collInfo.getBasicInfo(), nil
}
func (m *MetaCache) GetCollectionNamesByID(ctx context.Context, collectionIDs []UniqueID) ([]string, []string, error) {
hasUpdate := false
dbNames := make([]string, 0)
collectionNames := make([]string, 0)
for _, collectionID := range collectionIDs {
dbName, collectionName := m.innerGetCollectionByID(collectionID)
if dbName != "" {
dbNames = append(dbNames, dbName)
collectionNames = append(collectionNames, collectionName)
continue
}
if hasUpdate {
return nil, nil, errors.New("collection not found after meta cache has been updated")
}
hasUpdate = true
err := m.updateDBInfo(ctx)
if err != nil {
return nil, nil, err
}
dbName, collectionName = m.innerGetCollectionByID(collectionID)
if dbName == "" {
return nil, nil, errors.New("collection not found")
}
dbNames = append(dbNames, dbName)
collectionNames = append(collectionNames, collectionName)
}
return dbNames, collectionNames, nil
}
func (m *MetaCache) innerGetCollectionByID(collectionID int64) (string, string) {
m.mu.RLock()
defer m.mu.RUnlock()
for database, db := range m.dbInfo {
name, ok := db[collectionID]
if ok {
return database, name
}
}
return "", ""
}
func (m *MetaCache) updateDBInfo(ctx context.Context) error {
databaseResp, err := m.rootCoord.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ListDatabases)),
})
if err := merr.CheckRPCCall(databaseResp, err); err != nil {
log.Warn("failed to ListDatabases", zap.Error(err))
return err
}
dbInfo := make(map[string]map[int64]string)
for _, dbName := range databaseResp.DbNames {
resp, err := m.rootCoord.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
),
DbName: dbName,
})
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Warn("failed to ShowCollections",
zap.String("dbName", dbName),
zap.Error(err))
return err
}
collections := make(map[int64]string)
for i, collection := range resp.CollectionNames {
collections[resp.CollectionIds[i]] = collection
}
dbInfo[dbName] = collections
}
m.mu.Lock()
defer m.mu.Unlock()
m.dbInfo = dbInfo
return nil
}
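A short usage sketch of the new cache method (resolveNames is a hypothetical helper in the proxy package): the two returned slices are parallel to the input ids, which is why reportCollectionStorage later in this change checks their lengths before zipping them.

// resolveNames maps each collection id to its (database, collection) name pair.
func resolveNames(ctx context.Context, cache Cache, ids []int64) (map[int64][2]string, error) {
	dbNames, collNames, err := cache.GetCollectionNamesByID(ctx, ids)
	if err != nil {
		return nil, err
	}
	out := make(map[int64][2]string, len(ids))
	for i, id := range ids {
		out[id] = [2]string{dbNames[i], collNames[i]}
	}
	return out, nil
}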
// GetCollectionInfo returns the collection information related to the provided collection name
// If the information is not found, the proxy will try to fetch it from another source (RootCoord for now)
// TODO: this implementation may cause a data race; it should be refactored in the future.


@ -881,3 +881,149 @@ func TestMetaCache_AllocID(t *testing.T) {
assert.Equal(t, id, int64(0))
})
}
func TestGlobalMetaCache_UpdateDBInfo(t *testing.T) {
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := mocks.NewMockQueryCoordClient(t)
shardMgr := newShardClientMgr()
ctx := context.Background()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
t.Run("fail to list db", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Code: 500,
},
}, nil).Once()
err := cache.updateDBInfo(ctx)
assert.Error(t, err)
})
t.Run("fail to list collection", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
DbNames: []string{"db1"},
}, nil).Once()
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Code: 500,
},
}, nil).Once()
err := cache.updateDBInfo(ctx)
assert.Error(t, err)
})
t.Run("success", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
DbNames: []string{"db1"},
}, nil).Once()
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionNames: []string{"collection1"},
CollectionIds: []int64{1},
}, nil).Once()
err := cache.updateDBInfo(ctx)
assert.NoError(t, err)
assert.Len(t, cache.dbInfo, 1)
assert.Len(t, cache.dbInfo["db1"], 1)
assert.Equal(t, "collection1", cache.dbInfo["db1"][1])
})
}
func TestGlobalMetaCache_GetCollectionNamesByID(t *testing.T) {
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := mocks.NewMockQueryCoordClient(t)
shardMgr := newShardClientMgr()
ctx := context.Background()
t.Run("fail to update db info", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Code: 500,
},
}, nil).Once()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{1})
assert.Error(t, err)
})
t.Run("not found collection", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
DbNames: []string{"db1"},
}, nil).Once()
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionNames: []string{"collection1"},
CollectionIds: []int64{1},
}, nil).Once()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{2})
assert.Error(t, err)
})
t.Run("not found collection 2", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
DbNames: []string{"db1"},
}, nil).Once()
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionNames: []string{"collection1"},
CollectionIds: []int64{1},
}, nil).Once()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{1, 2})
assert.Error(t, err)
})
t.Run("success", func(t *testing.T) {
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
DbNames: []string{"db1"},
}, nil).Once()
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionNames: []string{"collection1", "collection2"},
CollectionIds: []int64{1, 2},
}, nil).Once()
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
assert.NoError(t, err)
dbNames, collectionNames, err := cache.GetCollectionNamesByID(ctx, []int64{1, 2})
assert.NoError(t, err)
assert.Equal(t, []string{"collection1", "collection2"}, collectionNames)
assert.Equal(t, []string{"db1", "db1"}, dbNames)
})
}


@ -275,6 +275,70 @@ func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Contex
return _c
}
// GetCollectionNamesByID provides a mock function with given fields: ctx, collectionID
func (_m *MockCache) GetCollectionNamesByID(ctx context.Context, collectionID []int64) ([]string, []string, error) {
ret := _m.Called(ctx, collectionID)
var r0 []string
var r1 []string
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]string, []string, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, []int64) []string); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, []int64) []string); ok {
r1 = rf(ctx, collectionID)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]string)
}
}
if rf, ok := ret.Get(2).(func(context.Context, []int64) error); ok {
r2 = rf(ctx, collectionID)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// MockCache_GetCollectionNamesByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionNamesByID'
type MockCache_GetCollectionNamesByID_Call struct {
*mock.Call
}
// GetCollectionNamesByID is a helper method to define mock.On call
// - ctx context.Context
// - collectionID []int64
func (_e *MockCache_Expecter) GetCollectionNamesByID(ctx interface{}, collectionID interface{}) *MockCache_GetCollectionNamesByID_Call {
return &MockCache_GetCollectionNamesByID_Call{Call: _e.mock.On("GetCollectionNamesByID", ctx, collectionID)}
}
func (_c *MockCache_GetCollectionNamesByID_Call) Run(run func(ctx context.Context, collectionID []int64)) *MockCache_GetCollectionNamesByID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]int64))
})
return _c
}
func (_c *MockCache_GetCollectionNamesByID_Call) Return(_a0 []string, _a1 []string, _a2 error) *MockCache_GetCollectionNamesByID_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
func (_c *MockCache_GetCollectionNamesByID_Call) RunAndReturn(run func(context.Context, []int64) ([]string, []string, error)) *MockCache_GetCollectionNamesByID_Call {
_c.Call.Return(run)
return _c
}
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemaInfo, error) {
ret := _m.Called(ctx, database, collectionName)


@ -26,11 +26,13 @@ import (
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -38,6 +40,7 @@ import (
"github.com/milvus-io/milvus/internal/proxy/connection"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
@ -45,6 +48,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/expr"
"github.com/milvus-io/milvus/pkg/util/logutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
@ -65,10 +69,11 @@ type Timestamp = typeutil.Timestamp
// make sure Proxy implements types.Proxy
var _ types.Proxy = (*Proxy)(nil)
var Params *paramtable.ComponentParam = paramtable.Get()
// rateCol is global rateCollector in Proxy.
var rateCol *ratelimitutil.RateCollector
var (
Params = paramtable.Get()
Extension hook.Extension
rateCol *ratelimitutil.RateCollector
)
// Proxy of milvus
type Proxy struct {
@ -151,6 +156,8 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
hookutil.InitOnceHook()
Extension = hookutil.Extension
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
return node, nil
}
@ -415,6 +422,12 @@ func (node *Proxy) Start() error {
cb()
}
Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
hookutil.NodeIDKey: paramtable.GetNodeID(),
})
node.startReportCollectionStorage()
log.Debug("update state code", zap.String("role", typeutil.ProxyRole), zap.String("State", commonpb.StateCode_Healthy.String()))
node.UpdateStateCode(commonpb.StateCode_Healthy)
@ -537,3 +550,87 @@ func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
}
return node.multiRateLimiter, nil
}
func (node *Proxy) startReportCollectionStorage() {
go func() {
tick := time.NewTicker(30 * time.Second)
defer tick.Stop()
for {
select {
case <-node.ctx.Done():
return
case <-tick.C:
_ = node.reportCollectionStorage()
}
}
}()
}
func (node *Proxy) reportCollectionStorage() error {
if node.dataCoord == nil {
return errors.New("nil datacoord")
}
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.CollectionStorageMetrics)
if err != nil {
return err
}
rsp, err := node.dataCoord.GetMetrics(node.ctx, req)
if err = merr.CheckRPCCall(rsp, err); err != nil {
log.Warn("failed to get metrics", zap.Error(err))
return err
}
dataCoordTopology := &metricsinfo.DataCoordTopology{}
err = metricsinfo.UnmarshalTopology(rsp.GetResponse(), dataCoordTopology)
if err != nil {
log.Warn("failed to unmarshal topology", zap.Error(err))
return err
}
quotaMetric := dataCoordTopology.Cluster.Self.QuotaMetrics
if quotaMetric == nil {
log.Warn("quota metric is nil")
return errors.New("quota metric is nil")
}
ctx, cancelFunc := context.WithTimeout(node.ctx, 5*time.Second)
defer cancelFunc()
ids := lo.Keys(quotaMetric.CollectionBinlogSize)
dbNames, collectionNames, err := globalMetaCache.GetCollectionNamesByID(ctx, ids)
if err != nil {
log.Warn("failed to get collection names", zap.Error(err))
return err
}
if len(ids) != len(dbNames) || len(ids) != len(collectionNames) {
log.Warn("failed to get collection names",
zap.Int("len(ids)", len(ids)),
zap.Int("len(dbNames)", len(dbNames)),
zap.Int("len(collectionNames)", len(collectionNames)))
return errors.New("failed to get collection names")
}
nameInfos := make(map[typeutil.UniqueID]lo.Tuple2[string, string])
for i, k := range ids {
nameInfos[k] = lo.Tuple2[string, string]{A: dbNames[i], B: collectionNames[i]}
}
storeInfo := make(map[string]int64)
for collectionID, dataSize := range quotaMetric.CollectionBinlogSize {
nameTuple, ok := nameInfos[collectionID]
if !ok {
continue
}
storeInfo[nameTuple.A] += dataSize
}
if len(storeInfo) > 0 {
Extension.Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeStorage,
hookutil.StorageDetailKey: lo.MapValues(storeInfo, func(v int64, _ string) any { return v }),
})
}
return nil
}


@ -60,6 +60,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
@ -1085,7 +1086,7 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
var insertedIds []int64
var insertedIDs []int64
wg.Add(1)
t.Run("insert", func(t *testing.T) {
defer wg.Done()
@ -1100,7 +1101,7 @@ func TestProxy(t *testing.T) {
switch field := resp.GetIDs().GetIdField().(type) {
case *schemapb.IDs_IntId:
insertedIds = field.IntId.GetData()
insertedIDs = field.IntId.GetData()
default:
t.Fatalf("Unexpected ID type")
}
@ -1611,7 +1612,7 @@ func TestProxy(t *testing.T) {
nq = 10
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIDs[0])
exprBytes := []byte(expr)
return &commonpb.PlaceholderGroup{
@ -4803,3 +4804,227 @@ func TestUnhealthProxy_GetIndexStatistics(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode())
})
}
func TestProxy_ReportCollectionStorage(t *testing.T) {
t.Run("nil datacoord", func(t *testing.T) {
proxy := &Proxy{}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("fail to get metric", func(t *testing.T) {
datacoord := mocks.NewMockDataCoordClient(t)
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Code: 500,
},
}, nil).Once()
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("fail to unmarshal metric", func(t *testing.T) {
datacoord := mocks.NewMockDataCoordClient(t)
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Response: "invalid",
}, nil).Once()
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("empty metric", func(t *testing.T) {
datacoord := mocks.NewMockDataCoordClient(t)
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
Cluster: metricsinfo.DataClusterTopology{
Self: metricsinfo.DataCoordInfos{},
},
})
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Response: string(r),
ComponentName: "DataCoord",
}, nil).Once()
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("fail to get cache", func(t *testing.T) {
origin := globalMetaCache
defer func() {
globalMetaCache = origin
}()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
datacoord := mocks.NewMockDataCoordClient(t)
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
Cluster: metricsinfo.DataClusterTopology{
Self: metricsinfo.DataCoordInfos{
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
TotalBinlogSize: 200,
CollectionBinlogSize: map[int64]int64{
1: 100,
2: 50,
3: 50,
},
},
},
},
})
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Response: string(r),
ComponentName: "DataCoord",
}, nil).Once()
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(nil, nil, errors.New("mock get collection names by id error")).Once()
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("not match data", func(t *testing.T) {
origin := globalMetaCache
defer func() {
globalMetaCache = origin
}()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
datacoord := mocks.NewMockDataCoordClient(t)
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
Cluster: metricsinfo.DataClusterTopology{
Self: metricsinfo.DataCoordInfos{
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
TotalBinlogSize: 200,
CollectionBinlogSize: map[int64]int64{
1: 100,
2: 50,
3: 50,
},
},
},
},
})
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Response: string(r),
ComponentName: "DataCoord",
}, nil).Once()
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(
[]string{"db1", "db1"}, []string{"col1", "col2"}, nil).Once()
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.Error(t, err)
})
t.Run("success", func(t *testing.T) {
origin := globalMetaCache
defer func() {
globalMetaCache = origin
}()
mockCache := NewMockCache(t)
globalMetaCache = mockCache
datacoord := mocks.NewMockDataCoordClient(t)
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
Cluster: metricsinfo.DataClusterTopology{
Self: metricsinfo.DataCoordInfos{
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
TotalBinlogSize: 200,
CollectionBinlogSize: map[int64]int64{
1: 100,
2: 50,
3: 50,
},
},
},
},
})
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Response: string(r),
ComponentName: "DataCoord",
}, nil).Once()
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(
[]string{"db1", "db1", "db2"}, []string{"col1", "col2", "col3"}, nil).Once()
originExtension := Extension
defer func() {
Extension = originExtension
}()
hasCheck := false
Extension = CheckExtension{
reportChecker: func(info any) {
infoMap := info.(map[string]any)
storage := infoMap[hookutil.StorageDetailKey].(map[string]any)
log.Info("storage map", zap.Any("storage", storage))
assert.EqualValues(t, 150, storage["db1"])
assert.EqualValues(t, 50, storage["db2"])
hasCheck = true
},
}
ctx := context.Background()
proxy := &Proxy{
ctx: ctx,
dataCoord: datacoord,
}
err := proxy.reportCollectionStorage()
assert.NoError(t, err)
assert.True(t, hasCheck)
})
}
type CheckExtension struct {
reportChecker func(info any)
}
func (c CheckExtension) Report(info any) int {
c.reportChecker(info)
return 0
}
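
The success case above pins down the intended rollup: reportCollectionStorage maps each collection's binlog size to its database via the meta cache and sums per database, so db1 ends up with 100+50=150 and db2 with 50. A self-contained sketch of that grouping, with the mocked inputs inlined:

package main

import "fmt"

func main() {
    // Inputs mirror the mocked DataCoord metrics and cache lookups in
    // the test above: collections 1 and 2 live in db1, collection 3 in db2.
    collectionBinlogSize := map[int64]int64{1: 100, 2: 50, 3: 50}
    dbOfCollection := map[int64]string{1: "db1", 2: "db1", 3: "db2"}

    // Roll per-collection binlog sizes up to database-level totals,
    // the shape the test asserts on under hookutil.StorageDetailKey.
    storage := make(map[string]int64)
    for collID, size := range collectionBinlogSize {
        storage[dbOfCollection[collID]] += size
    }
    fmt.Println(storage) // map[db1:150 db2:50]
}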

View File

@ -117,6 +117,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
ret.Results.AllSearchCount += sData.GetAllSearchCount()
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
@ -280,6 +281,7 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
zap.Int64("topk", sData.TopK),
zap.Int("length of pks", pkLength),
zap.Int("length of FieldsData", len(sData.FieldsData)))
ret.Results.AllSearchCount += sData.GetAllSearchCount()
if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
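
Both reduce paths, group-by and no-group-by, now fold each shard's AllSearchCount into the merged result before validating it, so the proxy-level scan total survives reduction. A minimal self-contained sketch of the fold, with a stand-in struct for schemapb.SearchResultData:

package main

import "fmt"

// subResult stands in for schemapb.SearchResultData; only the field
// used by the fold is modeled here.
type subResult struct{ allSearchCount int64 }

// mergeAllSearchCount mirrors the accumulation in both reduce paths:
// the merged result carries the sum over every shard's sub-result.
func mergeAllSearchCount(subs []subResult) int64 {
    var total int64
    for _, s := range subs {
        total += s.allSearchCount
    }
    return total
}

func main() {
    fmt.Println(mergeAllSearchCount([]subResult{{100}, {250}, {50}})) // 400
}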

View File

@ -61,7 +61,8 @@ type deleteTask struct {
msgID UniqueID
// result
count int64
count int64
allQueryCnt int64
}
func (dt *deleteTask) TraceCtx() context.Context {
@ -246,6 +247,8 @@ type deleteRunner struct {
// task queue
queue *dmTaskQueue
allQueryCnt atomic.Int64
}
func (dr *deleteRunner) Init(ctx context.Context) error {
@ -422,6 +425,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
taskCh := make(chan *deleteTask, 256)
go dr.receiveQueryResult(ctx, client, taskCh)
var allQueryCnt int64
// wait all task finish
for task := range taskCh {
err := task.WaitToFinish()
@ -429,12 +433,14 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
return err
}
dr.count.Add(task.count)
allQueryCnt += task.allQueryCnt
}
// query or produce task failed
if dr.err != nil {
return dr.err
}
dr.allQueryCnt.Add(allQueryCnt)
return nil
}
}
@ -468,6 +474,7 @@ func (dr *deleteRunner) receiveQueryResult(ctx context.Context, client querypb.Q
log.Warn("produce delete task failed", zap.Error(err))
return
}
task.allQueryCnt = result.GetAllRetrieveCount()
taskCh <- task
}
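
Each streaming batch sums its tasks' allQueryCnt into a stack-local total and publishes it with a single atomic add, so concurrent delete batches never contend per task. A hedged sketch of the same pattern using only the standard library's sync/atomic:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

type task struct{ allQueryCnt int64 }

func main() {
    var total atomic.Int64 // shared counter, like deleteRunner.allQueryCnt
    var wg sync.WaitGroup

    batches := [][]task{{{10}, {20}}, {{5}}}
    for _, batch := range batches {
        wg.Add(1)
        go func(batch []task) {
            defer wg.Done()
            var local int64 // per-batch accumulation on the stack
            for _, t := range batch {
                local += t.allQueryCnt
            }
            total.Add(local) // one atomic publish per batch
        }(batch)
    }
    wg.Wait()
    fmt.Println(total.Load()) // 35
}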

View File

@ -37,6 +37,7 @@ type hybridSearchTask struct {
ctx context.Context
*internalpb.HybridSearchRequest
dimension int64
result *milvuspb.SearchResults
request *milvuspb.HybridSearchRequest
searchTasks []*searchTask
@ -101,6 +102,11 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
if err != nil {
log.Warn("get collection dimension failed", zap.Error(err))
return err
}
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {
@ -529,6 +535,7 @@ func rankSearchResultData(ctx context.Context,
}
for _, result := range searchResults {
ret.Results.AllSearchCount += result.GetResults().GetAllSearchCount()
scores := result.GetResults().GetScores()
start := int64(0)
for i := int64(0); i < nq; i++ {

View File

@ -54,6 +54,7 @@ type queryTask struct {
collectionName string
queryParams *queryParams
schema *schemaInfo
dimension int64
userOutputFields []string
@ -65,7 +66,8 @@ type queryTask struct {
channelsMvcc map[string]Timestamp
fastSkip bool
reQuery bool
reQuery bool
allQueryCnt int64
}
type queryParams struct {
@ -333,8 +335,17 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
t.queryParams = queryParams
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
schema, _ := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
if err != nil {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
t.schema = schema
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
if err != nil {
log.Warn("get collection dimension failed", zap.Error(err))
return err
}
if t.ids != nil {
pkField := ""
@ -469,6 +480,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
var err error
toReduceResults := make([]*internalpb.RetrieveResults, 0)
t.allQueryCnt = 0
select {
case <-t.TraceCtx().Done():
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
@ -477,6 +489,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
log.Debug("all queries are finished or canceled")
t.resultBuf.Range(func(res *internalpb.RetrieveResults) bool {
toReduceResults = append(toReduceResults, res)
t.allQueryCnt += res.GetAllRetrieveCount()
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()))
return true
})

View File

@ -472,7 +472,6 @@ func (sched *taskScheduler) processTask(t task, q taskQueue) {
span.AddEvent("scheduler process PostExecute")
err = t.PostExecute(ctx)
if err != nil {
span.RecordError(err)
log.Ctx(ctx).Warn("Failed to post-execute task: ", zap.Error(err))

View File

@ -66,6 +66,7 @@ type searchTask struct {
userOutputFields []string
offset int64
dimension int64
resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults]
qc types.QueryCoordClient
@ -304,6 +305,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
if err != nil {
log.Warn("get collection dimension failed", zap.Error(err))
return err
}
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {

View File

@ -686,15 +686,15 @@ func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}
},
}
case schemapb.DataType_VarChar:
strIds := make([]string, len(data))
strIDs := make([]string, len(data))
for i, v := range data {
strIds[i] = strconv.FormatInt(v, 10)
strIDs[i] = strconv.FormatInt(v, 10)
}
fieldData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strIds,
Data: strIDs,
},
},
},
@ -903,6 +903,11 @@ func GetCurUserFromContext(ctx context.Context) (string, error) {
return username, nil
}
func GetCurUserFromContextOrDefault(ctx context.Context) string {
username, _ := GetCurUserFromContext(ctx)
return username
}
func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
@ -1634,3 +1639,16 @@ func CheckDatabase(ctx context.Context, dbName string) bool {
}
return false
}
func SetReportValue(status *commonpb.Status, value int) {
if value <= 0 {
return
}
if !merr.Ok(status) {
return
}
if status.ExtraInfo == nil {
status.ExtraInfo = make(map[string]string)
}
status.ExtraInfo["report_value"] = strconv.Itoa(value)
}
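
SetReportValue attaches the count only to successful statuses, under a fixed key in Status.ExtraInfo. A consumer on the hook side would invert it; GetReportValue below is an illustrative, hypothetical helper, not part of this change:

package main

import (
    "fmt"
    "strconv"
)

// GetReportValue is a hypothetical counterpart to SetReportValue: it
// recovers the count stashed under the "report_value" key, returning 0
// when the key is absent or malformed.
func GetReportValue(extraInfo map[string]string) int {
    if extraInfo == nil {
        return 0
    }
    v, err := strconv.Atoi(extraInfo["report_value"])
    if err != nil {
        return 0
    }
    return v
}

func main() {
    fmt.Println(GetReportValue(map[string]string{"report_value": "42"})) // 42
    fmt.Println(GetReportValue(nil))                                     // 0
}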

View File

@ -12,26 +12,34 @@ type cntReducer struct{}
func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
cnt := int64(0)
allRetrieveCount := int64(0)
for _, res := range results {
allRetrieveCount += res.GetAllRetrieveCount()
c, err := funcutil.CntOfInternalResult(res)
if err != nil {
return nil, err
}
cnt += c
}
return funcutil.WrapCntToInternalResult(cnt), nil
res := funcutil.WrapCntToInternalResult(cnt)
res.AllRetrieveCount = allRetrieveCount
return res, nil
}
type cntReducerSegCore struct{}
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
cnt := int64(0)
allRetrieveCount := int64(0)
for _, res := range results {
allRetrieveCount += res.GetAllRetrieveCount()
c, err := funcutil.CntOfSegCoreResult(res)
if err != nil {
return nil, err
}
cnt += c
}
return funcutil.WrapCntToSegCoreResult(cnt), nil
res := funcutil.WrapCntToSegCoreResult(cnt)
res.AllRetrieveCount = allRetrieveCount
return res, nil
}

View File

@ -130,6 +130,7 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
for j := int64(1); j < nq; j++ {
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
}
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
}
var skipDupCnt int64
@ -284,6 +285,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
validRetrieveResults := []*internalpb.RetrieveResults{}
for _, r := range retrieveResults {
ret.AllRetrieveCount += r.GetAllRetrieveCount()
size := typeutil.GetSizeOfIDs(r.GetIds())
if r == nil || len(r.GetFieldsData()) == 0 || size == 0 {
continue
@ -388,6 +390,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
validRetrieveResults := []*segcorepb.RetrieveResults{}
for _, r := range retrieveResults {
size := typeutil.GetSizeOfIDs(r.GetIds())
ret.AllRetrieveCount += r.GetAllRetrieveCount()
if r == nil || len(r.GetOffset()) == 0 || size == 0 {
log.Debug("filter out invalid retrieve result")
continue

View File

@ -118,9 +118,10 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
if len(result.GetOffset()) != 0 {
if err = svr.Send(&internalpb.RetrieveResults{
Status: merr.Success(),
Ids: result.GetIds(),
FieldsData: result.GetFieldsData(),
Status: merr.Success(),
Ids: result.GetIds(),
FieldsData: result.GetFieldsData(),
AllRetrieveCount: result.GetAllRetrieveCount(),
}); err != nil {
errs[i] = err
}

View File

@ -626,7 +626,7 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
numOfRow := len(rowIDs)
cOffset := C.int64_t(offset)
cNumOfRows := C.int64_t(numOfRow)
cEntityIdsPtr := (*C.int64_t)(&(rowIDs)[0])
cEntityIDsPtr := (*C.int64_t)(&(rowIDs)[0])
cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0])
var status C.CStatus
@ -635,7 +635,7 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
status = C.Insert(s.ptr,
cOffset,
cNumOfRows,
cEntityIdsPtr,
cEntityIDsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)),

View File

@ -135,6 +135,7 @@ func (t *QueryTask) Execute() error {
CostAggregation: &internalpb.CostAggregation{
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
AllRetrieveCount: reducedResult.GetAllRetrieveCount(),
}
return nil
}

View File

@ -48,10 +48,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
var generalNum int64 = 0
collectionsMap := core.meta.ListAllAvailCollections(ctx)
for dbId, collectionIds := range collectionsMap {
for dbId, collectionIDs := range collectionsMap {
db, err := core.meta.GetDatabaseByID(ctx, dbId, ts)
if err == nil {
for _, collectionId := range collectionIds {
for _, collectionId := range collectionIDs {
collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true)
if err == nil {
partNum := int64(collection.GetPartitionNum(false))

View File

@ -164,7 +164,7 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
}
tss := data.Data[common.TimeStampField].(*Int64FieldData)
rowIds := data.Data[common.RowIDField].(*Int64FieldData)
rowIDs := data.Data[common.RowIDField].(*Int64FieldData)
ls := fieldDataList{}
for fieldID := range data.Data {
@ -176,8 +176,8 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
ls.datas = append(ls.datas, data.Data[fieldID])
}
// checkNumRows(tss, rowIds, ls.datas...) // don't work
all := []FieldData{tss, rowIds}
// checkNumRows(tss, rowIDs, ls.datas...) // don't work
all := []FieldData{tss, rowIDs}
all = append(all, ls.datas...)
if !checkNumRows(all...) {
return nil, nil, nil,
@ -210,7 +210,7 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
utss[i] = uint64(tss.Data[i])
}
return utss, rowIds.Data, rows, nil
return utss, rowIDs.Data, rows, nil
}
///////////////////////////////////////////////////////////////////////////////////////////

View File

@ -125,10 +125,10 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) {
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
assert.Error(t, err)
rowIdsF := &Int64FieldData{
rowIDsF := &Int64FieldData{
Data: []int64{1, 2, 3, 4},
}
data.Data[common.RowIDField] = rowIdsF
data.Data[common.RowIDField] = rowIDsF
// row num mismatch
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
@ -193,10 +193,10 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) {
data.Data[111] = f11
data.Data[112] = f12
utss, rowIds, rows, err := TransferColumnBasedInsertDataToRowBased(data)
utss, rowIDs, rows, err := TransferColumnBasedInsertDataToRowBased(data)
assert.NoError(t, err)
assert.ElementsMatch(t, []uint64{1, 2, 3}, utss)
assert.ElementsMatch(t, []int64{1, 2, 3}, rowIds)
assert.ElementsMatch(t, []int64{1, 2, 3}, rowIDs)
assert.Equal(t, 3, len(rows))
// b := []byte("1")[0]
if common.Endian == binary.LittleEndian {

View File

@ -0,0 +1,43 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
var (
// WARN: Please DO NOT modify any of these constants.
OpTypeKey = "op_type"
DatabaseKey = "database"
UsernameKey = "username"
DataSizeKey = "data_size"
SuccessCntKey = "success_cnt"
FailCntKey = "fail_cnt"
RelatedCntKey = "related_cnt"
StorageDetailKey = "storage_detail"
NodeIDKey = "id"
DimensionKey = "dim"
OpTypeInsert = "insert"
OpTypeDelete = "delete"
OpTypeUpsert = "upsert"
OpTypeQuery = "query"
OpTypeSearch = "search"
OpTypeHybridSearch = "hybrid_search"
OpTypeStorage = "storage"
OpTypeNodeID = "node_id"
)
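
These keys define the loose schema of the map[string]any handed to Extension.Report: plugins and Milvus agree on field names through shared strings rather than a shared struct. A hedged sketch of assembling one insert report, with made-up values:

package main

import "fmt"

// Mirror of the constants above, inlined so the sketch is self-contained.
const (
    opTypeKey    = "op_type"
    databaseKey  = "database"
    dataSizeKey  = "data_size"
    opTypeInsert = "insert"
)

func main() {
    // A report is a plain map keyed by the shared constants; the values
    // here are invented purely for illustration.
    report := map[string]any{
        opTypeKey:   opTypeInsert,
        databaseKey: "default",
        dataSizeKey: 1 << 20, // bytes, made-up value
    }
    fmt.Println(report)
}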

View File

@ -0,0 +1,72 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
)
type DefaultHook struct{}
var _ hook.Hook = (*DefaultHook)(nil)
func (d DefaultHook) VerifyAPIKey(key string) (string, error) {
return "", errors.New("default hook, can't verify api key")
}
func (d DefaultHook) Init(params map[string]string) error {
return nil
}
func (d DefaultHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
return false, nil, nil
}
func (d DefaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
return ctx, nil
}
func (d DefaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
return nil
}
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}
func (d DefaultHook) Release() {}
type DefaultExtension struct{}
var _ hook.Extension = (*DefaultExtension)(nil)
func (d DefaultExtension) Report(info any) int {
return 0
}
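
The `var _ hook.Hook = (*DefaultHook)(nil)` lines are compile-time interface assertions: if the type ever drifts from the interface, the build fails at the declaration instead of at a distant call site. The same guard in a self-contained form:

package main

import "fmt"

// Reporter is a stand-in for hook.Extension, reduced to one method.
type Reporter interface{ Report(info any) int }

type noopReporter struct{}

func (noopReporter) Report(info any) int { return 0 }

// Compile-time proof that noopReporter satisfies Reporter: the build
// breaks here, not somewhere far away, if the method set drifts.
var _ Reporter = noopReporter{}

func main() { fmt.Println(noopReporter{}.Report(nil)) }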

View File

@ -0,0 +1,102 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
import (
"fmt"
"plugin"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
var (
Hoo hook.Hook
Extension hook.Extension
initOnce sync.Once
)
func initHook() error {
Hoo = DefaultHook{}
Extension = DefaultExtension{}
path := paramtable.Get().ProxyCfg.SoPath.GetValue()
if path == "" {
log.Info("empty so path, skip to load plugin")
return nil
}
log.Info("start to load plugin", zap.String("path", path))
p, err := plugin.Open(path)
if err != nil {
return fmt.Errorf("fail to open the plugin, error: %s", err.Error())
}
log.Info("plugin open")
h, err := p.Lookup("MilvusHook")
if err != nil {
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
}
var ok bool
Hoo, ok = h.(hook.Hook)
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event))
go func() {
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = Hoo.Init(soConfig); err != nil {
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
}
}()
})
e, err := p.Lookup("MilvusExtension")
if err != nil {
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
}
Extension, ok = e.(hook.Extension)
if !ok {
return fmt.Errorf("fail to convert the `Extension` interface")
}
return nil
}
func InitOnceHook() {
initOnce.Do(func() {
err := initHook()
if err != nil {
log.Warn("fail to init hook",
zap.String("so_path", paramtable.Get().ProxyCfg.SoPath.GetValue()),
zap.Error(err))
}
})
}
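
initHook looks up two exported symbols, MilvusHook and MilvusExtension, and type-asserts them against hook.Hook and hook.Extension; plugin.Lookup returns a pointer to an exported variable, so value-receiver methods still satisfy the assertion. A hedged sketch of what the plugin side might look like, assuming the interface matches the DefaultHook method set shown earlier:

// hookplugin.go: a hedged sketch of the plugin side; build with
//   go build -buildmode=plugin -o hook.so hookplugin.go
// The method set below mirrors DefaultHook above and is an assumption
// about the hook.Hook interface, not a copy of it.
package main

import (
    "context"
    "errors"
)

type myHook struct{}

func (myHook) Init(params map[string]string) error { return nil }
func (myHook) VerifyAPIKey(key string) (string, error) {
    return "", errors.New("not supported")
}
func (myHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
    return false, nil, nil
}
func (myHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
    return ctx, nil
}
func (myHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
    return nil
}
func (myHook) Release() {}

type myExtension struct{}

func (myExtension) Report(info any) int { return 0 }

// Exported package-level symbols found by plugin.Lookup in initHook.
var (
    MilvusHook      myHook
    MilvusExtension myExtension
)

func main() {} // never run; plugins are loaded, not executed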

View File

@ -0,0 +1,52 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestInitHook(t *testing.T) {
paramtable.Init()
Params := paramtable.Get()
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook()
assert.IsType(t, DefaultHook{}, Hoo)
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook()
assert.Error(t, err)
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
}
func TestDefaultHook(t *testing.T) {
d := &DefaultHook{}
assert.NoError(t, d.Init(nil))
{
_, err := d.VerifyAPIKey("key")
assert.Error(t, err)
}
assert.NotPanics(t, func() {
d.Release()
})
}

View File

@ -13,7 +13,7 @@ func appendFieldData(result RetrieveResults, fieldData *schemapb.FieldData) {
result.AppendFieldData(fieldData)
}
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, schema *schemapb.CollectionSchema) error {
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIDs []int64, schema *schemapb.CollectionSchema) error {
if !result.ResultEmpty() {
return nil
}
@ -24,7 +24,7 @@ func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, s
if err != nil {
return err
}
for _, outputFieldID := range outputFieldIds {
for _, outputFieldID := range outputFieldIDs {
field, err := helper.GetFieldFromID(outputFieldID)
if err != nil {
return err

View File

@ -14,7 +14,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.16.5
github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f
github.com/nats-io/nats-server/v2 v2.9.17
github.com/nats-io/nats.go v1.24.0
github.com/panjf2000/ants/v2 v2.7.2

View File

@ -483,8 +483,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46 h1:IgoGNTbsRPa2kdNI+IWuZrrortFEjTB42/gYDklZHVU=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f h1:f8rRJ5zatNq2WszAwy6S+J0Z2h7/CArqLDJ0gTHSFNs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A=
github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w=
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=

View File

@ -27,6 +27,9 @@ const (
// SystemInfoMetrics means users request for system information metrics.
SystemInfoMetrics = "system_info"
// CollectionStorageMetrics means users request for collection storage metrics.
CollectionStorageMetrics = "collection_storage"
)
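
CollectionStorageMetrics is only a new value for the metric-type field carried in GetMetricsRequest.Request as JSON. Assuming the same "metric_type" envelope used by the system_info type, a request body would be built like this:

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    // Assumption: the metric type travels under a "metric_type" key in
    // GetMetricsRequest.Request, as it does for the system_info type.
    body, err := json.Marshal(map[string]string{
        "metric_type": "collection_storage",
    })
    if err != nil {
        panic(err)
    }
    fmt.Println(string(body)) // {"metric_type":"collection_storage"}
}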
// ParseMetricType returns the metric type of req

View File

@ -27,3 +27,14 @@ func GetDim(field *schemapb.FieldSchema) (int64, error) {
}
return int64(dim), nil
}
func GetCollectionDim(collection *schemapb.CollectionSchema) (int64, error) {
for _, fieldSchema := range collection.GetFields() {
dim, err := GetDim(fieldSchema)
if err != nil {
continue
}
return dim, nil
}
return 0, fmt.Errorf("dim not found")
}
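
GetCollectionDim returns the dim of the first field whose type params parse, skipping scalar fields that carry no dim; with several vector fields it silently reports the first one, which the callers in this change only use for per-row size estimates. A self-contained sketch with stand-in types:

package main

import (
    "fmt"
    "strconv"
)

// field stands in for schemapb.FieldSchema; only the type params matter here.
type field struct{ typeParams map[string]string }

func getDim(f field) (int64, error) {
    v, ok := f.typeParams["dim"]
    if !ok {
        return 0, fmt.Errorf("dim not found")
    }
    return strconv.ParseInt(v, 10, 64)
}

// getCollectionDim mirrors the loop above: the first parsable dim wins.
func getCollectionDim(fields []field) (int64, error) {
    for _, f := range fields {
        if dim, err := getDim(f); err == nil {
            return dim, nil
        }
    }
    return 0, fmt.Errorf("dim not found")
}

func main() {
    fields := []field{
        {typeParams: map[string]string{}},             // scalar field: no dim, skipped
        {typeParams: map[string]string{"dim": "128"}}, // vector field
    }
    fmt.Println(getCollectionDim(fields)) // 128 <nil>
}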

View File

@ -20,6 +20,7 @@ import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
@ -28,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@ -77,6 +79,27 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("insert check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("insert report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeInsert, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
return
}
}
}
go insertCheckReport()
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
@ -145,11 +168,99 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(ctx, searchReq)
searchCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("search check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("search report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go searchCheckReport()
searchResult, err := c.Proxy.Search(ctx, searchReq)
err = merr.CheckRPCCall(searchResult, err)
s.NoError(err)
queryCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("query check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("query report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go queryCheckReport()
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
})
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
deleteCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("delete check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
s.EqualValues(2, reportInfo[hookutil.SuccessCntKey])
s.EqualValues(0, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go deleteCheckReport()
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: integration.Int64Field + " in [1, 2]",
})
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{
CollectionName: collectionName,
})

View File

@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -125,6 +126,8 @@ type MiniClusterV2 struct {
qnid atomic.Int64
datanodes []*grpcdatanode.Server
dnid atomic.Int64
Extension *ReportChanExtension
}
type OptionV2 func(cluster *MiniClusterV2)
@ -136,6 +139,8 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2,
dnid: *atomic.NewInt64(20000),
}
paramtable.Init()
cluster.Extension = InitReportExtension()
cluster.params = DefaultParams()
cluster.clusterConfig = DefaultClusterConfig()
for _, opt := range opts {
@ -445,3 +450,32 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}
func InitReportExtension() *ReportChanExtension {
e := NewReportChanExtension()
hookutil.InitOnceHook()
hookutil.Extension = e
return e
}
type ReportChanExtension struct {
reportChan chan any
}
func NewReportChanExtension() *ReportChanExtension {
return &ReportChanExtension{
reportChan: make(chan any),
}
}
func (r *ReportChanExtension) Report(info any) int {
select {
case r.reportChan <- info:
default:
}
return 1
}
func (r *ReportChanExtension) GetReportChan() <-chan any {
return r.reportChan
}
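
Report sends with a select/default, so when no test goroutine is currently blocked on GetReportChan the report is dropped instead of stalling the proxy path that emitted it; that is why each check in the integration tests starts its receiver goroutine before issuing the request. A small demonstration of the trade-off:

package main

import (
    "fmt"
    "time"
)

// send mirrors ReportChanExtension.Report: deliver if a receiver is
// ready, otherwise drop the report instead of blocking the caller.
func send(ch chan any, info any) bool {
    select {
    case ch <- info:
        return true
    default:
        return false
    }
}

func main() {
    ch := make(chan any)             // unbuffered, like reportChan
    fmt.Println(send(ch, "dropped")) // false: nobody is receiving yet

    go func() {
        fmt.Println("got:", <-ch)
    }()
    time.Sleep(10 * time.Millisecond)  // crude sync, fine for a sketch
    fmt.Println(send(ch, "delivered")) // true: the goroutine is blocked on <-ch
    time.Sleep(10 * time.Millisecond)  // let the receiver print before exit
}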

View File

@ -0,0 +1,411 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package partitionkey
import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
type PartitionKeySuite struct {
integration.MiniClusterSuite
}
func (s *PartitionKeySuite) TestPartitionKey() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := s.Cluster
const (
dim = 128
dbName = ""
rowNum = 1000
)
collectionName := "TestPartitionKey" + funcutil.GenRandomStr()
schema := integration.ConstructSchema(collectionName, dim, false)
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: 102,
Name: "pid",
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
IsPartitionKey: true,
})
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
}
s.Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
{
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, 0)
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 1)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
}
{
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, rowNum)
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 2)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
}
{
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, rowNum*2)
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 3)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
}
flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
// create index
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
// load
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
s.WaitForLoad(ctx, collectionName)
{
// search without partition key
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("search check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("search report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(rowNum*3, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go searchCheckReport()
searchResult, err := c.Proxy.Search(ctx, searchReq)
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
}
{
// search with partition key
expr := fmt.Sprintf("%s > 0 && pid == 1", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("search check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("search report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go searchCheckReport()
searchResult, err := c.Proxy.Search(ctx, searchReq)
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
}
{
// query without partition key
queryCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("query check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("query report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(3*rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go queryCheckReport()
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
})
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
}
{
// query with partition key
queryCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("query check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("query report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go queryCheckReport()
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: "pid == 1",
OutputFields: []string{"count(*)"},
})
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
}
{
// delete without partition key
deleteCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("delete check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
s.EqualValues(rowNum, reportInfo[hookutil.SuccessCntKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go deleteCheckReport()
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: integration.Int64Field + " < 1000",
})
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
}
{
// delete with partition key
deleteCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("delete check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
continue
}
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
s.EqualValues(rowNum, reportInfo[hookutil.SuccessCntKey])
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
return
}
}
}
go deleteCheckReport()
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: integration.Int64Field + " < 2000 && pid == 2",
})
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
}
}
func TestPartitionKey(t *testing.T) {
suite.Run(t, new(PartitionKeySuite))
}

View File

@ -28,6 +28,7 @@ import (
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -104,6 +105,24 @@ func (s *MiniClusterSuite) SetupTest() {
s.Cluster = c
// start mini cluster
nodeIDCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for {
select {
case <-timeoutCtx.Done():
s.Fail("node id check timeout")
case report := <-c.Extension.GetReportChan():
reportInfo := report.(map[string]any)
s.T().Log("node id report info: ", reportInfo)
s.Equal(hookutil.OpTypeNodeID, reportInfo[hookutil.OpTypeKey])
s.NotEqualValues(0, reportInfo[hookutil.NodeIDKey])
return
}
}
}
go nodeIDCheckReport()
s.Require().NoError(s.Cluster.Start())
}

View File

@ -58,7 +58,39 @@ func NewInt64FieldData(fieldName string, numRows int) *schemapb.FieldData {
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: GenerateInt64Array(numRows),
Data: GenerateInt64Array(numRows, 0),
},
},
},
},
}
}
func NewInt64FieldDataWithStart(fieldName string, numRows int, start int64) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: GenerateInt64Array(numRows, start),
},
},
},
},
}
}
func NewInt64SameFieldData(fieldName string, numRows int, value int64) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: GenerateSameInt64Array(numRows, value),
},
},
},
@ -144,10 +176,18 @@ func NewBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.Fiel
}
}
func GenerateInt64Array(numRows int) []int64 {
func GenerateInt64Array(numRows int, start int64) []int64 {
ret := make([]int64, numRows)
for i := 0; i < numRows; i++ {
ret[i] = int64(i)
ret[i] = int64(i) + start
}
return ret
}
func GenerateSameInt64Array(numRows int, value int64) []int64 {
ret := make([]int64, numRows)
for i := 0; i < numRows; i++ {
ret[i] = value
}
return ret
}