mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-29 18:38:44 +08:00
feat: add more operation detail info for better allocation (#30438)
issue: #30436 --------- Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
fbff46a005
commit
b1a1cca10b
@ -140,6 +140,8 @@ issues:
|
||||
- which can be annoying to use
|
||||
# Binds to all network interfaces
|
||||
- G102
|
||||
# Use of unsafe calls should be audited
|
||||
- G103
|
||||
# Errors unhandled
|
||||
- G104
|
||||
# file/folder Permission
|
||||
@ -164,4 +166,5 @@ issues:
|
||||
max-same-issues: 0
|
||||
|
||||
service:
|
||||
golangci-lint-version: 1.55.2 # use the fixed version to not introduce new linters unexpectedly
|
||||
# use the fixed version to not introduce new linters unexpectedly
|
||||
golangci-lint-version: 1.55.2
|
||||
|
17
go.mod
17
go.mod
@ -9,6 +9,7 @@ require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0
|
||||
github.com/aliyun/credentials-go v1.2.7
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e
|
||||
github.com/apache/arrow/go/v12 v12.0.1
|
||||
github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7
|
||||
github.com/bits-and-blooms/bloom/v3 v3.0.1
|
||||
github.com/blang/semver/v4 v4.0.0
|
||||
@ -17,6 +18,7 @@ require (
|
||||
github.com/cockroachdb/errors v1.9.1
|
||||
github.com/containerd/cgroups/v3 v3.0.3 // indirect
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-playground/validator/v10 v10.14.0
|
||||
github.com/gofrs/flock v0.8.1
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/golang/protobuf v1.5.3
|
||||
@ -26,9 +28,11 @@ require (
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240317125658-67a0f065c1de
|
||||
github.com/minio/minio-go/v7 v7.0.61
|
||||
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
github.com/prometheus/client_model v0.3.0
|
||||
github.com/prometheus/common v0.42.0
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22
|
||||
github.com/samber/lo v1.27.0
|
||||
github.com/sbinet/npyio v0.6.0
|
||||
github.com/soheilhy/cmux v0.1.5
|
||||
@ -36,6 +40,7 @@ require (
|
||||
github.com/spf13/viper v1.8.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tecbot/gorocksdb v0.0.0-20191217155057-f0fad39f321c
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.865
|
||||
github.com/tidwall/gjson v1.14.4
|
||||
github.com/tikv/client-go/v2 v2.0.4
|
||||
go.etcd.io/etcd/api/v3 v3.5.5
|
||||
@ -49,6 +54,7 @@ require (
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.16.0
|
||||
golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691
|
||||
golang.org/x/net v0.19.0
|
||||
golang.org/x/oauth2 v0.8.0
|
||||
golang.org/x/sync v0.5.0
|
||||
golang.org/x/text v0.14.0
|
||||
@ -56,18 +62,9 @@ require (
|
||||
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
|
||||
)
|
||||
|
||||
require github.com/apache/arrow/go/v12 v12.0.1
|
||||
|
||||
require github.com/milvus-io/milvus-storage/go v0.0.0-20231227072638-ebd0b8e56d70
|
||||
|
||||
require (
|
||||
github.com/go-playground/validator/v10 v10.14.0
|
||||
github.com/milvus-io/milvus/pkg v0.0.0-00010101000000-000000000000
|
||||
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.865
|
||||
golang.org/x/net v0.19.0
|
||||
)
|
||||
require github.com/milvus-io/milvus/pkg v0.0.0-00010101000000-000000000000
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute v1.20.1 // indirect
|
||||
|
@ -189,6 +189,7 @@ struct SearchResult {
|
||||
public:
|
||||
int64_t total_nq_;
|
||||
int64_t unity_topK_;
|
||||
int64_t total_data_cnt_;
|
||||
void* segment_;
|
||||
|
||||
// first fill data during search, and then update data after reducing search results
|
||||
@ -223,6 +224,7 @@ struct RetrieveResult {
|
||||
RetrieveResult() = default;
|
||||
|
||||
public:
|
||||
int64_t total_data_cnt_;
|
||||
void* segment_;
|
||||
std::vector<int64_t> result_offsets_;
|
||||
std::vector<DataArray> field_data_;
|
||||
|
@ -71,6 +71,7 @@ empty_search_result(int64_t num_queries, SearchInfo& search_info) {
|
||||
SearchResult final_result;
|
||||
final_result.total_nq_ = num_queries;
|
||||
final_result.unity_topK_ = 0; // no result
|
||||
final_result.total_data_cnt_ = 0;
|
||||
return final_result;
|
||||
}
|
||||
|
||||
@ -212,6 +213,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
|
||||
timestamp_,
|
||||
final_view,
|
||||
search_result);
|
||||
search_result.total_data_cnt_ = final_view.size();
|
||||
if (search_result.vector_iterators_.has_value()) {
|
||||
std::vector<GroupByValueType> group_by_values;
|
||||
GroupBy(search_result.vector_iterators_.value(),
|
||||
@ -239,6 +241,7 @@ wrap_num_entities(int64_t cnt) {
|
||||
auto scalar = arr.mutable_scalars();
|
||||
scalar->mutable_long_data()->mutable_data()->Add(cnt);
|
||||
retrieve_result->field_data_ = {arr};
|
||||
retrieve_result->total_data_cnt_ = 0;
|
||||
return retrieve_result;
|
||||
}
|
||||
|
||||
@ -249,6 +252,7 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
||||
dynamic_cast<const segcore::SegmentInternalInterface*>(&segment_);
|
||||
AssertInfo(segment, "Support SegmentSmallIndex Only");
|
||||
RetrieveResult retrieve_result;
|
||||
retrieve_result.total_data_cnt_ = 0;
|
||||
|
||||
auto active_count = segment->get_active_count(timestamp_);
|
||||
|
||||
@ -295,10 +299,12 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
|
||||
if (node.is_count_) {
|
||||
auto cnt = bitset_holder.size() - bitset_holder.count();
|
||||
retrieve_result = *(wrap_num_entities(cnt));
|
||||
retrieve_result.total_data_cnt_ = bitset_holder.size();
|
||||
retrieve_result_opt_ = std::move(retrieve_result);
|
||||
return;
|
||||
}
|
||||
|
||||
retrieve_result.total_data_cnt_ = bitset_holder.size();
|
||||
bool false_filtered_out = false;
|
||||
if (get_cache_offset) {
|
||||
segment->timestamp_filter(bitset_holder, cache_offsets, timestamp_);
|
||||
|
@ -348,12 +348,14 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
||||
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
|
||||
|
||||
int64_t result_count = 0;
|
||||
int64_t all_search_count = 0;
|
||||
for (auto search_result : search_results_) {
|
||||
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() ==
|
||||
search_result->total_nq_ + 1,
|
||||
"incorrect topk_per_nq_prefix_sum_ size in search result");
|
||||
result_count += search_result->topk_per_nq_prefix_sum_[nq_end] -
|
||||
search_result->topk_per_nq_prefix_sum_[nq_begin];
|
||||
all_search_count += search_result->total_data_cnt_;
|
||||
}
|
||||
|
||||
auto search_result_data =
|
||||
@ -363,6 +365,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
||||
search_result_data->set_num_queries(nq_end - nq_begin);
|
||||
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
|
||||
|
||||
search_result_data->set_all_search_count(all_search_count);
|
||||
|
||||
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
|
||||
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
|
||||
|
||||
|
@ -100,6 +100,7 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan,
|
||||
fmt::format("query results exceed the limit size ", limit_size));
|
||||
}
|
||||
|
||||
results->set_all_retrieve_count(retrieve_results.total_data_cnt_);
|
||||
if (plan->plan_node_->is_count_) {
|
||||
AssertInfo(retrieve_results.field_data_.size() == 1,
|
||||
"count result should only have one column");
|
||||
|
@ -205,7 +205,7 @@ func (ib *indexBuilder) run() {
|
||||
}
|
||||
}
|
||||
|
||||
func getBinLogIds(segment *SegmentInfo, fieldID int64) []int64 {
|
||||
func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 {
|
||||
binlogIDs := make([]int64, 0)
|
||||
for _, fieldBinLog := range segment.GetBinlogs() {
|
||||
if fieldBinLog.GetFieldID() == fieldID {
|
||||
@ -299,7 +299,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool {
|
||||
FieldID: partitionKeyField.FieldID,
|
||||
FieldName: partitionKeyField.Name,
|
||||
FieldType: int32(partitionKeyField.DataType),
|
||||
DataIds: getBinLogIds(segment, partitionKeyField.FieldID),
|
||||
DataIds: getBinLogIDs(segment, partitionKeyField.FieldID),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -333,7 +333,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool {
|
||||
}
|
||||
|
||||
fieldID := ib.meta.indexMeta.GetFieldIDByIndexID(meta.CollectionID, meta.IndexID)
|
||||
binlogIDs := getBinLogIds(segment, fieldID)
|
||||
binlogIDs := getBinLogIDs(segment, fieldID)
|
||||
if isDiskANNIndex(GetIndexType(indexParams)) {
|
||||
var err error
|
||||
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
|
||||
|
@ -139,6 +139,31 @@ func (s *Server) getSystemInfoMetrics(
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Server) getCollectionStorageMetrics(ctx context.Context) (*milvuspb.GetMetricsResponse, error) {
|
||||
coordTopology := metricsinfo.DataCoordTopology{
|
||||
Cluster: metricsinfo.DataClusterTopology{
|
||||
Self: s.getDataCoordMetrics(ctx),
|
||||
},
|
||||
Connections: metricsinfo.ConnTopology{
|
||||
Name: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()),
|
||||
ConnectedComponents: []metricsinfo.ConnectionInfo{},
|
||||
},
|
||||
}
|
||||
|
||||
resp := &milvuspb.GetMetricsResponse{
|
||||
Status: merr.Success(),
|
||||
ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()),
|
||||
}
|
||||
var err error
|
||||
resp.Response, err = metricsinfo.MarshalTopology(coordTopology)
|
||||
if err != nil {
|
||||
resp.Status = merr.Status(err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getDataCoordMetrics composes datacoord infos
|
||||
func (s *Server) getDataCoordMetrics(ctx context.Context) metricsinfo.DataCoordInfos {
|
||||
ret := metricsinfo.DataCoordInfos{
|
||||
|
@ -236,8 +236,8 @@ func TestLastExpireReset(t *testing.T) {
|
||||
assert.Equal(t, expire1, segment1.GetLastExpireTime())
|
||||
assert.Equal(t, expire2, segment2.GetLastExpireTime())
|
||||
assert.True(t, segment3.GetLastExpireTime() > expire3)
|
||||
flushableSegIds, _ := newSegmentManager.GetFlushableSegments(context.Background(), channelName, expire3)
|
||||
assert.ElementsMatch(t, []UniqueID{segmentID1, segmentID2}, flushableSegIds) // segment1 and segment2 can be flushed
|
||||
flushableSegIDs, _ := newSegmentManager.GetFlushableSegments(context.Background(), channelName, expire3)
|
||||
assert.ElementsMatch(t, []UniqueID{segmentID1, segmentID2}, flushableSegIDs) // segment1 and segment2 can be flushed
|
||||
newAlloc, err := newSegmentManager.AllocSegment(context.Background(), collID, 0, channelName, 2000)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, segmentID3, newAlloc[0].SegmentID) // segment3 still can be used to allocate
|
||||
|
@ -94,7 +94,7 @@ type rootCoordCreatorFunc func(ctx context.Context) (types.RootCoordClient, erro
|
||||
// makes sure Server implements `DataCoord`
|
||||
var _ types.DataCoord = (*Server)(nil)
|
||||
|
||||
var Params *paramtable.ComponentParam = paramtable.Get()
|
||||
var Params = paramtable.Get()
|
||||
|
||||
// Server implements `types.DataCoord`
|
||||
// handles Data Coordinator related jobs
|
||||
@ -161,6 +161,11 @@ type Server struct {
|
||||
broker broker.Broker
|
||||
}
|
||||
|
||||
type CollectionNameInfo struct {
|
||||
CollectionName string
|
||||
DBName string
|
||||
}
|
||||
|
||||
// ServerHelper datacoord server injection helper
|
||||
type ServerHelper struct {
|
||||
eventAfterHandleDataNodeTt func()
|
||||
|
@ -34,6 +34,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@ -3836,3 +3837,67 @@ func TestUpdateAutoBalanceConfigLoop(t *testing.T) {
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCollectionStorage(t *testing.T) {
|
||||
paramtable.Init()
|
||||
mockSession := sessionutil.NewMockSession(t)
|
||||
mockSession.EXPECT().GetAddress().Return("localhost:8888")
|
||||
size := atomic.NewInt64(100)
|
||||
|
||||
s := &Server{
|
||||
session: mockSession,
|
||||
meta: &meta{
|
||||
segments: &SegmentsInfo{
|
||||
segments: map[UniqueID]*SegmentInfo{
|
||||
1: {
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
State: commonpb.SegmentState_Growing,
|
||||
CollectionID: 10001,
|
||||
PartitionID: 10000,
|
||||
NumOfRows: 10,
|
||||
},
|
||||
size: *size,
|
||||
},
|
||||
2: {
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 2,
|
||||
State: commonpb.SegmentState_Dropped,
|
||||
CollectionID: 10001,
|
||||
PartitionID: 10000,
|
||||
NumOfRows: 10,
|
||||
},
|
||||
size: *size,
|
||||
},
|
||||
3: {
|
||||
SegmentInfo: &datapb.SegmentInfo{
|
||||
ID: 3,
|
||||
State: commonpb.SegmentState_Flushed,
|
||||
CollectionID: 10002,
|
||||
PartitionID: 9999,
|
||||
NumOfRows: 10,
|
||||
},
|
||||
size: *size,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.CollectionStorageMetrics)
|
||||
assert.NoError(t, err)
|
||||
resp, err := s.GetMetrics(context.TODO(), req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var coordTopology metricsinfo.DataCoordTopology
|
||||
err = metricsinfo.UnmarshalTopology(resp.Response, &coordTopology)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := coordTopology.Cluster.Self.QuotaMetrics
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, int64(200), m.TotalBinlogSize)
|
||||
assert.Len(t, m.CollectionBinlogSize, 2)
|
||||
assert.Equal(t, int64(100), m.CollectionBinlogSize[10001])
|
||||
assert.Equal(t, int64(100), m.CollectionBinlogSize[10002])
|
||||
}
|
||||
|
@ -1029,6 +1029,23 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
|
||||
zap.Any("metrics", metrics), // TODO(dragondriver): necessary? may be very large
|
||||
zap.Error(err))
|
||||
|
||||
return metrics, nil
|
||||
} else if metricType == metricsinfo.CollectionStorageMetrics {
|
||||
metrics, err := s.getCollectionStorageMetrics(ctx)
|
||||
if err != nil {
|
||||
log.Warn("DataCoord GetMetrics CollectionStorage failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.Error(err))
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.RatedDebug(60, "DataCoord.GetMetrics CollectionStorage",
|
||||
zap.Int64("nodeID", paramtable.GetNodeID()),
|
||||
zap.String("req", req.Request),
|
||||
zap.String("metricType", metricType),
|
||||
zap.Any("metrics", metrics),
|
||||
zap.Error(err))
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
|
@ -232,7 +232,6 @@ func (c *SessionManagerImpl) SyncSegments(nodeID int64, req *datapb.SyncSegments
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warn("failed to sync segments after retry", zap.Error(err))
|
||||
return err
|
||||
|
@ -108,20 +108,20 @@ func getVchanInfo(info *testInfo) *datapb.VchannelInfo {
|
||||
ufs = []*datapb.SegmentInfo{}
|
||||
}
|
||||
|
||||
var ufsIds []int64
|
||||
var fsIds []int64
|
||||
var ufsIDs []int64
|
||||
var fsIDs []int64
|
||||
for _, segmentInfo := range ufs {
|
||||
ufsIds = append(ufsIds, segmentInfo.ID)
|
||||
ufsIDs = append(ufsIDs, segmentInfo.ID)
|
||||
}
|
||||
for _, segmentInfo := range fs {
|
||||
fsIds = append(fsIds, segmentInfo.ID)
|
||||
fsIDs = append(fsIDs, segmentInfo.ID)
|
||||
}
|
||||
vi := &datapb.VchannelInfo{
|
||||
CollectionID: info.collID,
|
||||
ChannelName: info.chanName,
|
||||
SeekPosition: &msgpb.MsgPosition{},
|
||||
UnflushedSegmentIds: ufsIds,
|
||||
FlushedSegmentIds: fsIds,
|
||||
UnflushedSegmentIds: ufsIDs,
|
||||
FlushedSegmentIds: fsIDs,
|
||||
}
|
||||
return vi
|
||||
}
|
||||
@ -465,13 +465,13 @@ func (s *DataSyncServiceSuite) TestStartStop() {
|
||||
NumOfRows: 0,
|
||||
DmlPosition: &msgpb.MsgPosition{},
|
||||
}}
|
||||
var ufsIds []int64
|
||||
var fsIds []int64
|
||||
var ufsIDs []int64
|
||||
var fsIDs []int64
|
||||
for _, segmentInfo := range ufs {
|
||||
ufsIds = append(ufsIds, segmentInfo.ID)
|
||||
ufsIDs = append(ufsIDs, segmentInfo.ID)
|
||||
}
|
||||
for _, segmentInfo := range fs {
|
||||
fsIds = append(fsIds, segmentInfo.ID)
|
||||
fsIDs = append(fsIDs, segmentInfo.ID)
|
||||
}
|
||||
|
||||
watchInfo := &datapb.ChannelWatchInfo{
|
||||
@ -479,8 +479,8 @@ func (s *DataSyncServiceSuite) TestStartStop() {
|
||||
Vchan: &datapb.VchannelInfo{
|
||||
CollectionID: collMeta.ID,
|
||||
ChannelName: insertChannelName,
|
||||
UnflushedSegmentIds: ufsIds,
|
||||
FlushedSegmentIds: fsIds,
|
||||
UnflushedSegmentIds: ufsIDs,
|
||||
FlushedSegmentIds: fsIDs,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -691,7 +691,7 @@ func TestInsert(t *testing.T) {
|
||||
mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil)
|
||||
mp5.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(schemapb.DataType_Int64),
|
||||
IDs: genIDs(schemapb.DataType_Int64),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
@ -705,7 +705,7 @@ func TestInsert(t *testing.T) {
|
||||
mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil)
|
||||
mp6.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(schemapb.DataType_VarChar),
|
||||
IDs: genIDs(schemapb.DataType_VarChar),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
@ -776,7 +776,7 @@ func TestInsertForDataType(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(schemapb.DataType_Int64),
|
||||
IDs: genIDs(schemapb.DataType_Int64),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -844,7 +844,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -875,7 +875,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -906,7 +906,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -938,7 +938,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -971,7 +971,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -1002,7 +1002,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -1033,7 +1033,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
InsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -1065,7 +1065,7 @@ func TestReturnInt64(t *testing.T) {
|
||||
}, nil).Once()
|
||||
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(dataType),
|
||||
IDs: genIDs(dataType),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
@ -1132,7 +1132,7 @@ func TestUpsert(t *testing.T) {
|
||||
mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil)
|
||||
mp5.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(schemapb.DataType_Int64),
|
||||
IDs: genIDs(schemapb.DataType_Int64),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
@ -1146,7 +1146,7 @@ func TestUpsert(t *testing.T) {
|
||||
mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil)
|
||||
mp6.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{
|
||||
Status: &StatusSuccess,
|
||||
IDs: genIds(schemapb.DataType_VarChar),
|
||||
IDs: genIDs(schemapb.DataType_VarChar),
|
||||
UpsertCnt: 3,
|
||||
}, nil).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
@ -1198,8 +1198,8 @@ func TestUpsert(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func genIds(dataType schemapb.DataType) *schemapb.IDs {
|
||||
return generateIds(dataType, 3)
|
||||
func genIDs(dataType schemapb.DataType) *schemapb.IDs {
|
||||
return generateIDs(dataType, 3)
|
||||
}
|
||||
|
||||
func TestSearch(t *testing.T) {
|
||||
|
@ -37,7 +37,7 @@ func generatePrimaryField(datatype schemapb.DataType) schemapb.FieldSchema {
|
||||
}
|
||||
}
|
||||
|
||||
func generateIds(dataType schemapb.DataType, num int) *schemapb.IDs {
|
||||
func generateIDs(dataType schemapb.DataType, num int) *schemapb.IDs {
|
||||
var intArray []int64
|
||||
if num == 0 {
|
||||
intArray = []int64{}
|
||||
@ -684,7 +684,7 @@ func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, c
|
||||
|
||||
func TestBuildQueryResp(t *testing.T) {
|
||||
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
|
||||
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
|
||||
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
|
||||
assert.Equal(t, nil, err)
|
||||
exceptRows := generateSearchResult(schemapb.DataType_Int64)
|
||||
assert.Equal(t, true, compareRows(rows, exceptRows, compareRow))
|
||||
@ -1298,7 +1298,7 @@ func TestBuildQueryResps(t *testing.T) {
|
||||
outputFields := []string{"XXX", "YYY"}
|
||||
outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}}
|
||||
for _, theOutputFields := range outputFieldsList {
|
||||
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
assert.Equal(t, nil, err)
|
||||
exceptRows := newSearchResult(generateSearchResult(schemapb.DataType_Int64))
|
||||
assert.Equal(t, true, compareRows(rows, exceptRows, compareRow))
|
||||
@ -1312,29 +1312,29 @@ func TestBuildQueryResps(t *testing.T) {
|
||||
schemapb.DataType_JSON, schemapb.DataType_Array,
|
||||
}
|
||||
for _, dateType := range dataTypes {
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
assert.Equal(t, "the type(1000) of field(wrong-field-type) is not supported, use other sdk please", err.Error())
|
||||
|
||||
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true)
|
||||
assert.Equal(t, 3, len(res))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), DefaultScores, false)
|
||||
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false)
|
||||
assert.Equal(t, 3, len(res))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_VarChar, 3), DefaultScores, true)
|
||||
res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_VarChar, 3), DefaultScores, true)
|
||||
assert.Equal(t, 3, len(res))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
_, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), DefaultScores, false)
|
||||
_, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// len(rows) != len(scores), didn't show distance
|
||||
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true)
|
||||
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
@ -125,16 +125,16 @@ func (suite *CatalogTestSuite) TestPartition() {
|
||||
}
|
||||
|
||||
func (suite *CatalogTestSuite) TestReleaseManyPartitions() {
|
||||
partitionIds := make([]int64, 0)
|
||||
partitionIDs := make([]int64, 0)
|
||||
for i := 1; i <= 150; i++ {
|
||||
suite.catalog.SavePartition(&querypb.PartitionLoadInfo{
|
||||
CollectionID: 1,
|
||||
PartitionID: int64(i),
|
||||
})
|
||||
partitionIds = append(partitionIds, int64(i))
|
||||
partitionIDs = append(partitionIDs, int64(i))
|
||||
}
|
||||
|
||||
err := suite.catalog.ReleasePartition(1, partitionIds...)
|
||||
err := suite.catalog.ReleasePartition(1, partitionIDs...)
|
||||
suite.NoError(err)
|
||||
partitions, err := suite.catalog.GetPartitions()
|
||||
suite.NoError(err)
|
||||
|
@ -16,6 +16,6 @@ type Segment struct {
|
||||
CreatedByCompaction bool
|
||||
SegmentState commonpb.SegmentState
|
||||
// IndexInfos []*SegmentIndex
|
||||
ReplicaIds []int64
|
||||
NodeIds []int64
|
||||
ReplicaIDs []int64
|
||||
NodeIDs []int64
|
||||
}
|
||||
|
@ -134,6 +134,7 @@ message SearchResults {
|
||||
// search request cost
|
||||
CostAggregation costAggregation = 13;
|
||||
map<string, uint64> channels_mvcc = 14;
|
||||
int64 all_search_count = 15;
|
||||
}
|
||||
|
||||
message CostAggregation {
|
||||
@ -175,6 +176,7 @@ message RetrieveResults {
|
||||
|
||||
// query request cost
|
||||
CostAggregation costAggregation = 13;
|
||||
int64 all_retrieve_count = 14;
|
||||
}
|
||||
|
||||
message LoadIndex {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ message RetrieveResults {
|
||||
schema.IDs ids = 1;
|
||||
repeated int64 offset = 2;
|
||||
repeated schema.FieldData fields_data = 3;
|
||||
int64 all_retrieve_count = 4;
|
||||
}
|
||||
|
||||
message LoadFieldMeta {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
@ -140,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
|
||||
user, _ := parseMD(rawToken)
|
||||
assert.Equal(t, "mockUser", user)
|
||||
}
|
||||
hoo = defaultHook{}
|
||||
hoo = hookutil.DefaultHook{}
|
||||
}
|
||||
|
@ -177,7 +177,6 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
|
||||
var err error
|
||||
|
||||
stream, err = factory.NewMsgStream(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type cntReducer struct {
|
||||
@ -20,6 +21,7 @@ func (r *cntReducer) Reduce(results []*internalpb.RetrieveResults) (*milvuspb.Qu
|
||||
cnt += c
|
||||
}
|
||||
res := funcutil.WrapCntToQueryResults(cnt)
|
||||
res.Status = merr.Success()
|
||||
res.CollectionName = r.collectionName
|
||||
return res, nil
|
||||
}
|
||||
|
@ -2,93 +2,24 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"plugin"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
type defaultHook struct{}
|
||||
|
||||
func (d defaultHook) VerifyAPIKey(key string) (string, error) {
|
||||
return "", errors.New("default hook, can't verify api key")
|
||||
}
|
||||
|
||||
func (d defaultHook) Init(params map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Release() {}
|
||||
|
||||
var hoo hook.Hook
|
||||
|
||||
func initHook() error {
|
||||
path := Params.ProxyCfg.SoPath.GetValue()
|
||||
if path == "" {
|
||||
hoo = defaultHook{}
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("start to load plugin", zap.String("path", path))
|
||||
p, err := plugin.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to open the plugin, error: %s", err.Error())
|
||||
}
|
||||
logger.Debug("plugin open")
|
||||
|
||||
h, err := p.Lookup("MilvusHook")
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
|
||||
}
|
||||
|
||||
var ok bool
|
||||
hoo, ok = h.(hook.Hook)
|
||||
if !ok {
|
||||
return fmt.Errorf("fail to convert the `Hook` interface")
|
||||
}
|
||||
if err = hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
|
||||
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
|
||||
}
|
||||
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
|
||||
log.Info("receive the hook refresh event", zap.Any("event", event))
|
||||
go func() {
|
||||
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
|
||||
log.Info("refresh hook configs", zap.Any("config", soConfig))
|
||||
if err = hoo.Init(soConfig); err != nil {
|
||||
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||
if hookError := initHook(); hookError != nil {
|
||||
logger.Error("hook error", zap.String("path", Params.ProxyCfg.SoPath.GetValue()), zap.Error(hookError))
|
||||
hoo = defaultHook{}
|
||||
}
|
||||
hookutil.InitOnceHook()
|
||||
hoo = hookutil.Hoo
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
var (
|
||||
fullMethod = info.FullMethod
|
||||
@ -145,24 +76,13 @@ func getCurrentUser(ctx context.Context) string {
|
||||
return username
|
||||
}
|
||||
|
||||
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
|
||||
type MockAPIHook struct {
|
||||
defaultHook
|
||||
mockErr error
|
||||
apiUser string
|
||||
}
|
||||
|
||||
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
|
||||
return m.apiUser, m.mockErr
|
||||
}
|
||||
|
||||
func SetMockAPIHook(apiUser string, mockErr error) {
|
||||
if apiUser == "" && mockErr == nil {
|
||||
hoo = defaultHook{}
|
||||
hoo = &hookutil.DefaultHook{}
|
||||
return
|
||||
}
|
||||
hoo = MockAPIHook{
|
||||
mockErr: mockErr,
|
||||
apiUser: apiUser,
|
||||
hoo = &hookutil.MockAPIHook{
|
||||
MockErr: mockErr,
|
||||
User: apiUser,
|
||||
}
|
||||
}
|
||||
|
@ -8,22 +8,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
)
|
||||
|
||||
func TestInitHook(t *testing.T) {
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
|
||||
initHook()
|
||||
assert.IsType(t, defaultHook{}, hoo)
|
||||
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
|
||||
err := initHook()
|
||||
assert.Error(t, err)
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
|
||||
}
|
||||
|
||||
type mockHook struct {
|
||||
defaultHook
|
||||
hookutil.DefaultHook
|
||||
mockRes interface{}
|
||||
mockErr error
|
||||
}
|
||||
@ -39,7 +28,7 @@ type req struct {
|
||||
type BeforeMockCtxKey int
|
||||
|
||||
type beforeMock struct {
|
||||
defaultHook
|
||||
hookutil.DefaultHook
|
||||
method string
|
||||
ctxKey BeforeMockCtxKey
|
||||
ctxValue string
|
||||
@ -60,7 +49,7 @@ type resp struct {
|
||||
}
|
||||
|
||||
type afterMock struct {
|
||||
defaultHook
|
||||
hookutil.DefaultHook
|
||||
method string
|
||||
err error
|
||||
}
|
||||
@ -129,7 +118,7 @@ func TestHookInterceptor(t *testing.T) {
|
||||
assert.Equal(t, re.method, afterHoo.method)
|
||||
assert.Equal(t, err, afterHoo.err)
|
||||
|
||||
hoo = defaultHook{}
|
||||
hoo = &hookutil.DefaultHook{}
|
||||
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return &resp{
|
||||
method: r.(*req).method,
|
||||
@ -139,18 +128,6 @@ func TestHookInterceptor(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultHook(t *testing.T) {
|
||||
d := defaultHook{}
|
||||
assert.NoError(t, d.Init(nil))
|
||||
{
|
||||
_, err := d.VerifyAPIKey("key")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
d.Release()
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateProxyFunctionCallMetric(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
updateProxyFunctionCallMetric("/milvus.proto.milvus.MilvusService/Flush")
|
||||
|
@ -42,6 +42,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proxy/connection"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/importutilv2"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
@ -2394,6 +2395,15 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
successCnt := it.result.InsertCnt - int64(len(it.result.ErrIndex))
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeInsert,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.DataSizeKey: proto.Size(request),
|
||||
hookutil.SuccessCntKey: successCnt,
|
||||
hookutil.FailCntKey: len(it.result.ErrIndex),
|
||||
})
|
||||
SetReportValue(it.result.GetStatus(), v)
|
||||
metrics.ProxyInsertVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
|
||||
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
@ -2469,6 +2479,15 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
|
||||
successCnt := dr.result.GetDeleteCnt()
|
||||
metrics.ProxyDeleteVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
|
||||
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeDelete,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.SuccessCntKey: successCnt,
|
||||
hookutil.RelatedCntKey: dr.allQueryCnt.Load(),
|
||||
})
|
||||
SetReportValue(dr.result.GetStatus(), v)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
@ -2584,6 +2603,16 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
||||
// UpsertCnt always equals to the number of entities in the request
|
||||
it.result.UpsertCnt = int64(request.NumRows)
|
||||
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.DataSizeKey: proto.Size(it.req),
|
||||
hookutil.SuccessCntKey: it.result.UpsertCnt,
|
||||
hookutil.FailCntKey: len(it.result.ErrIndex),
|
||||
})
|
||||
SetReportValue(it.result.GetStatus(), v)
|
||||
|
||||
rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.DeleteMsg.Size()+it.upsertMsg.DeleteMsg.Size()))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
@ -2759,6 +2788,15 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
||||
|
||||
if qt.result != nil {
|
||||
sentSize := proto.Size(qt.result)
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeSearch,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.DataSizeKey: sentSize,
|
||||
hookutil.RelatedCntKey: qt.result.GetResults().GetAllSearchCount(),
|
||||
hookutil.DimensionKey: qt.dimension,
|
||||
})
|
||||
SetReportValue(qt.result.GetStatus(), v)
|
||||
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
|
||||
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
|
||||
}
|
||||
@ -2902,6 +2940,15 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
||||
|
||||
if qt.result != nil {
|
||||
sentSize := proto.Size(qt.result)
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.DataSizeKey: sentSize,
|
||||
hookutil.RelatedCntKey: qt.result.GetResults().GetAllSearchCount(),
|
||||
hookutil.DimensionKey: qt.dimension,
|
||||
})
|
||||
SetReportValue(qt.result.GetStatus(), v)
|
||||
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
|
||||
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
|
||||
}
|
||||
@ -3182,7 +3229,19 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
||||
qc: node.queryCoord,
|
||||
lb: node.lbPolicy,
|
||||
}
|
||||
return node.query(ctx, qt)
|
||||
res, err := node.query(ctx, qt)
|
||||
if merr.Ok(res.Status) && err == nil {
|
||||
v := Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeQuery,
|
||||
hookutil.DatabaseKey: request.DbName,
|
||||
hookutil.UsernameKey: GetCurUserFromContextOrDefault(ctx),
|
||||
hookutil.DataSizeKey: proto.Size(res),
|
||||
hookutil.RelatedCntKey: qt.allQueryCnt,
|
||||
hookutil.DimensionKey: qt.dimension,
|
||||
})
|
||||
SetReportValue(res.Status, v)
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
// CreateAlias create alias for collection, then you can search the collection with alias.
|
||||
|
@ -59,6 +59,8 @@ type Cache interface {
|
||||
GetCollectionName(ctx context.Context, database string, collectionID int64) (string, error)
|
||||
// GetCollectionInfo get collection's information by name or collection id, such as schema, and etc.
|
||||
GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionBasicInfo, error)
|
||||
// GetCollectionNamesByID get collection name and database name by collection id
|
||||
GetCollectionNamesByID(ctx context.Context, collectionID []UniqueID) ([]string, []string, error)
|
||||
// GetPartitionID get partition's identifier of specific collection.
|
||||
GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error)
|
||||
// GetPartitions get all partitions' id of specific collection.
|
||||
@ -242,11 +244,12 @@ type MetaCache struct {
|
||||
rootCoord types.RootCoordClient
|
||||
queryCoord types.QueryCoordClient
|
||||
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
|
||||
privilegeInfos map[string]struct{} // privileges cache
|
||||
userToRoles map[string]map[string]struct{} // user to role cache
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
dbInfo map[string]map[typeutil.UniqueID]string // database -> collectionID -> collectionName
|
||||
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
|
||||
privilegeInfos map[string]struct{} // privileges cache
|
||||
userToRoles map[string]map[string]struct{} // user to role cache
|
||||
mu sync.RWMutex
|
||||
credMut sync.RWMutex
|
||||
leaderMut sync.RWMutex
|
||||
@ -288,6 +291,7 @@ func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordCl
|
||||
queryCoord: queryCoord,
|
||||
collInfo: map[string]map[string]*collectionInfo{},
|
||||
collLeader: map[string]map[string]*shardLeaders{},
|
||||
dbInfo: map[string]map[typeutil.UniqueID]string{},
|
||||
credMap: map[string]*internalpb.CredentialInfo{},
|
||||
shardMgr: shardMgr,
|
||||
privilegeInfos: map[string]struct{}{},
|
||||
@ -471,6 +475,90 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
|
||||
return collInfo.getBasicInfo(), nil
|
||||
}
|
||||
|
||||
func (m *MetaCache) GetCollectionNamesByID(ctx context.Context, collectionIDs []UniqueID) ([]string, []string, error) {
|
||||
hasUpdate := false
|
||||
|
||||
dbNames := make([]string, 0)
|
||||
collectionNames := make([]string, 0)
|
||||
for _, collectionID := range collectionIDs {
|
||||
dbName, collectionName := m.innerGetCollectionByID(collectionID)
|
||||
if dbName != "" {
|
||||
dbNames = append(dbNames, dbName)
|
||||
collectionNames = append(collectionNames, collectionName)
|
||||
continue
|
||||
}
|
||||
if hasUpdate {
|
||||
return nil, nil, errors.New("collection not found after meta cache has been updated")
|
||||
}
|
||||
hasUpdate = true
|
||||
err := m.updateDBInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dbName, collectionName = m.innerGetCollectionByID(collectionID)
|
||||
if dbName == "" {
|
||||
return nil, nil, errors.New("collection not found")
|
||||
}
|
||||
dbNames = append(dbNames, dbName)
|
||||
collectionNames = append(collectionNames, collectionName)
|
||||
}
|
||||
|
||||
return dbNames, collectionNames, nil
|
||||
}
|
||||
|
||||
func (m *MetaCache) innerGetCollectionByID(collectionID int64) (string, string) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for database, db := range m.dbInfo {
|
||||
name, ok := db[collectionID]
|
||||
if ok {
|
||||
return database, name
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (m *MetaCache) updateDBInfo(ctx context.Context) error {
|
||||
databaseResp, err := m.rootCoord.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{
|
||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ListDatabases)),
|
||||
})
|
||||
|
||||
if err := merr.CheckRPCCall(databaseResp, err); err != nil {
|
||||
log.Warn("failed to ListDatabases", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
dbInfo := make(map[string]map[int64]string)
|
||||
for _, dbName := range databaseResp.DbNames {
|
||||
resp, err := m.rootCoord.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
|
||||
),
|
||||
DbName: dbName,
|
||||
})
|
||||
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
log.Warn("failed to ShowCollections",
|
||||
zap.String("dbName", dbName),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
collections := make(map[int64]string)
|
||||
for i, collection := range resp.CollectionNames {
|
||||
collections[resp.CollectionIds[i]] = collection
|
||||
}
|
||||
dbInfo[dbName] = collections
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.dbInfo = dbInfo
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCollectionInfo returns the collection information related to provided collection name
|
||||
// If the information is not found, proxy will try to fetch information for other source (RootCoord for now)
|
||||
// TODO: may cause data race of this implementation, should be refactored in future.
|
||||
|
@ -881,3 +881,149 @@ func TestMetaCache_AllocID(t *testing.T) {
|
||||
assert.Equal(t, id, int64(0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGlobalMetaCache_UpdateDBInfo(t *testing.T) {
|
||||
rootCoord := mocks.NewMockRootCoordClient(t)
|
||||
queryCoord := mocks.NewMockQueryCoordClient(t)
|
||||
shardMgr := newShardClientMgr()
|
||||
ctx := context.Background()
|
||||
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("fail to list db", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Code: 500,
|
||||
},
|
||||
}, nil).Once()
|
||||
err := cache.updateDBInfo(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to list collection", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
DbNames: []string{"db1"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Code: 500,
|
||||
},
|
||||
}, nil).Once()
|
||||
err := cache.updateDBInfo(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
DbNames: []string{"db1"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionNames: []string{"collection1"},
|
||||
CollectionIds: []int64{1},
|
||||
}, nil).Once()
|
||||
err := cache.updateDBInfo(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cache.dbInfo, 1)
|
||||
assert.Len(t, cache.dbInfo["db1"], 1)
|
||||
assert.Equal(t, "collection1", cache.dbInfo["db1"][1])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGlobalMetaCache_GetCollectionNamesByID(t *testing.T) {
|
||||
rootCoord := mocks.NewMockRootCoordClient(t)
|
||||
queryCoord := mocks.NewMockQueryCoordClient(t)
|
||||
shardMgr := newShardClientMgr()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("fail to update db info", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Code: 500,
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{1})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not found collection", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
DbNames: []string{"db1"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionNames: []string{"collection1"},
|
||||
CollectionIds: []int64{1},
|
||||
}, nil).Once()
|
||||
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{2})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not found collection 2", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
DbNames: []string{"db1"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionNames: []string{"collection1"},
|
||||
CollectionIds: []int64{1},
|
||||
}, nil).Once()
|
||||
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = cache.GetCollectionNamesByID(ctx, []int64{1, 2})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
rootCoord.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
DbNames: []string{"db1"},
|
||||
}, nil).Once()
|
||||
rootCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionNames: []string{"collection1", "collection2"},
|
||||
CollectionIds: []int64{1, 2},
|
||||
}, nil).Once()
|
||||
|
||||
cache, err := NewMetaCache(rootCoord, queryCoord, shardMgr)
|
||||
assert.NoError(t, err)
|
||||
dbNames, collectionNames, err := cache.GetCollectionNamesByID(ctx, []int64{1, 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"collection1", "collection2"}, collectionNames)
|
||||
assert.Equal(t, []string{"db1", "db1"}, dbNames)
|
||||
})
|
||||
}
|
||||
|
@ -275,6 +275,70 @@ func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Contex
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionNamesByID provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockCache) GetCollectionNamesByID(ctx context.Context, collectionID []int64) ([]string, []string, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
||||
var r0 []string
|
||||
var r1 []string
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]string, []string, error)); ok {
|
||||
return rf(ctx, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []int64) []string); ok {
|
||||
r0 = rf(ctx, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, []int64) []string); ok {
|
||||
r1 = rf(ctx, collectionID)
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(2).(func(context.Context, []int64) error); ok {
|
||||
r2 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// MockCache_GetCollectionNamesByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionNamesByID'
|
||||
type MockCache_GetCollectionNamesByID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetCollectionNamesByID is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - collectionID []int64
|
||||
func (_e *MockCache_Expecter) GetCollectionNamesByID(ctx interface{}, collectionID interface{}) *MockCache_GetCollectionNamesByID_Call {
|
||||
return &MockCache_GetCollectionNamesByID_Call{Call: _e.mock.On("GetCollectionNamesByID", ctx, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionNamesByID_Call) Run(run func(ctx context.Context, collectionID []int64)) *MockCache_GetCollectionNamesByID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].([]int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionNamesByID_Call) Return(_a0 []string, _a1 []string, _a2 error) *MockCache_GetCollectionNamesByID_Call {
|
||||
_c.Call.Return(_a0, _a1, _a2)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCollectionNamesByID_Call) RunAndReturn(run func(context.Context, []int64) ([]string, []string, error)) *MockCache_GetCollectionNamesByID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemaInfo, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
|
@ -26,11 +26,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
@ -38,6 +40,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proxy/connection"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
@ -45,6 +48,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/expr"
|
||||
"github.com/milvus-io/milvus/pkg/util/logutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
|
||||
@ -65,10 +69,11 @@ type Timestamp = typeutil.Timestamp
|
||||
// make sure Proxy implements types.Proxy
|
||||
var _ types.Proxy = (*Proxy)(nil)
|
||||
|
||||
var Params *paramtable.ComponentParam = paramtable.Get()
|
||||
|
||||
// rateCol is global rateCollector in Proxy.
|
||||
var rateCol *ratelimitutil.RateCollector
|
||||
var (
|
||||
Params = paramtable.Get()
|
||||
Extension hook.Extension
|
||||
rateCol *ratelimitutil.RateCollector
|
||||
)
|
||||
|
||||
// Proxy of milvus
|
||||
type Proxy struct {
|
||||
@ -151,6 +156,8 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
expr.Register("proxy", node)
|
||||
hookutil.InitOnceHook()
|
||||
Extension = hookutil.Extension
|
||||
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
|
||||
return node, nil
|
||||
}
|
||||
@ -415,6 +422,12 @@ func (node *Proxy) Start() error {
|
||||
cb()
|
||||
}
|
||||
|
||||
Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
|
||||
hookutil.NodeIDKey: paramtable.GetNodeID(),
|
||||
})
|
||||
node.startReportCollectionStorage()
|
||||
|
||||
log.Debug("update state code", zap.String("role", typeutil.ProxyRole), zap.String("State", commonpb.StateCode_Healthy.String()))
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
@ -537,3 +550,87 @@ func (node *Proxy) GetRateLimiter() (types.Limiter, error) {
|
||||
}
|
||||
return node.multiRateLimiter, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) startReportCollectionStorage() {
|
||||
go func() {
|
||||
tick := time.NewTicker(30 * time.Second)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-node.ctx.Done():
|
||||
return
|
||||
case <-tick.C:
|
||||
_ = node.reportCollectionStorage()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (node *Proxy) reportCollectionStorage() error {
|
||||
if node.dataCoord == nil {
|
||||
return errors.New("nil datacoord")
|
||||
}
|
||||
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.CollectionStorageMetrics)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rsp, err := node.dataCoord.GetMetrics(node.ctx, req)
|
||||
if err = merr.CheckRPCCall(rsp, err); err != nil {
|
||||
log.Warn("failed to get metrics", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
dataCoordTopology := &metricsinfo.DataCoordTopology{}
|
||||
err = metricsinfo.UnmarshalTopology(rsp.GetResponse(), dataCoordTopology)
|
||||
if err != nil {
|
||||
log.Warn("failed to unmarshal topology", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
quotaMetric := dataCoordTopology.Cluster.Self.QuotaMetrics
|
||||
if quotaMetric == nil {
|
||||
log.Warn("quota metric is nil")
|
||||
return errors.New("quota metric is nil")
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(node.ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
ids := lo.Keys(quotaMetric.CollectionBinlogSize)
|
||||
dbNames, collectionNames, err := globalMetaCache.GetCollectionNamesByID(ctx, ids)
|
||||
if err != nil {
|
||||
log.Warn("failed to get collection names", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ids) != len(dbNames) || len(ids) != len(collectionNames) {
|
||||
log.Warn("failed to get collection names",
|
||||
zap.Int("len(ids)", len(ids)),
|
||||
zap.Int("len(dbNames)", len(dbNames)),
|
||||
zap.Int("len(collectionNames)", len(collectionNames)))
|
||||
return errors.New("failed to get collection names")
|
||||
}
|
||||
|
||||
nameInfos := make(map[typeutil.UniqueID]lo.Tuple2[string, string])
|
||||
for i, k := range ids {
|
||||
nameInfos[k] = lo.Tuple2[string, string]{A: dbNames[i], B: collectionNames[i]}
|
||||
}
|
||||
|
||||
storeInfo := make(map[string]int64)
|
||||
for collectionID, dataSize := range quotaMetric.CollectionBinlogSize {
|
||||
nameTuple, ok := nameInfos[collectionID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
storeInfo[nameTuple.A] += dataSize
|
||||
}
|
||||
|
||||
if len(storeInfo) > 0 {
|
||||
Extension.Report(map[string]any{
|
||||
hookutil.OpTypeKey: hookutil.OpTypeStorage,
|
||||
hookutil.StorageDetailKey: lo.MapValues(storeInfo, func(v int64, _ string) any { return v }),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -60,6 +60,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
@ -1085,7 +1086,7 @@ func TestProxy(t *testing.T) {
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
var insertedIds []int64
|
||||
var insertedIDs []int64
|
||||
wg.Add(1)
|
||||
t.Run("insert", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
@ -1100,7 +1101,7 @@ func TestProxy(t *testing.T) {
|
||||
|
||||
switch field := resp.GetIDs().GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
insertedIds = field.IntId.GetData()
|
||||
insertedIDs = field.IntId.GetData()
|
||||
default:
|
||||
t.Fatalf("Unexpected ID type")
|
||||
}
|
||||
@ -1611,7 +1612,7 @@ func TestProxy(t *testing.T) {
|
||||
nq = 10
|
||||
|
||||
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
|
||||
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
|
||||
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIDs[0])
|
||||
exprBytes := []byte(expr)
|
||||
|
||||
return &commonpb.PlaceholderGroup{
|
||||
@ -4803,3 +4804,227 @@ func TestUnhealthProxy_GetIndexStatistics(t *testing.T) {
|
||||
assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy_ReportCollectionStorage(t *testing.T) {
|
||||
t.Run("nil datacoord", func(t *testing.T) {
|
||||
proxy := &Proxy{}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to get metric", func(t *testing.T) {
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Code: 500,
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to unmarshal metric", func(t *testing.T) {
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Response: "invalid",
|
||||
}, nil).Once()
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty metric", func(t *testing.T) {
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
|
||||
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
|
||||
Cluster: metricsinfo.DataClusterTopology{
|
||||
Self: metricsinfo.DataCoordInfos{},
|
||||
},
|
||||
})
|
||||
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Response: string(r),
|
||||
ComponentName: "DataCoord",
|
||||
}, nil).Once()
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to get cache", func(t *testing.T) {
|
||||
origin := globalMetaCache
|
||||
defer func() {
|
||||
globalMetaCache = origin
|
||||
}()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
|
||||
Cluster: metricsinfo.DataClusterTopology{
|
||||
Self: metricsinfo.DataCoordInfos{
|
||||
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
|
||||
TotalBinlogSize: 200,
|
||||
CollectionBinlogSize: map[int64]int64{
|
||||
1: 100,
|
||||
2: 50,
|
||||
3: 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Response: string(r),
|
||||
ComponentName: "DataCoord",
|
||||
}, nil).Once()
|
||||
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(nil, nil, errors.New("mock get collection names by id error")).Once()
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not match data", func(t *testing.T) {
|
||||
origin := globalMetaCache
|
||||
defer func() {
|
||||
globalMetaCache = origin
|
||||
}()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
|
||||
Cluster: metricsinfo.DataClusterTopology{
|
||||
Self: metricsinfo.DataCoordInfos{
|
||||
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
|
||||
TotalBinlogSize: 200,
|
||||
CollectionBinlogSize: map[int64]int64{
|
||||
1: 100,
|
||||
2: 50,
|
||||
3: 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Response: string(r),
|
||||
ComponentName: "DataCoord",
|
||||
}, nil).Once()
|
||||
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(
|
||||
[]string{"db1", "db1"}, []string{"col1", "col2"}, nil).Once()
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
origin := globalMetaCache
|
||||
defer func() {
|
||||
globalMetaCache = origin
|
||||
}()
|
||||
mockCache := NewMockCache(t)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
datacoord := mocks.NewMockDataCoordClient(t)
|
||||
r, _ := json.Marshal(&metricsinfo.DataCoordTopology{
|
||||
Cluster: metricsinfo.DataClusterTopology{
|
||||
Self: metricsinfo.DataCoordInfos{
|
||||
QuotaMetrics: &metricsinfo.DataCoordQuotaMetrics{
|
||||
TotalBinlogSize: 200,
|
||||
CollectionBinlogSize: map[int64]int64{
|
||||
1: 100,
|
||||
2: 50,
|
||||
3: 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
datacoord.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Response: string(r),
|
||||
ComponentName: "DataCoord",
|
||||
}, nil).Once()
|
||||
mockCache.EXPECT().GetCollectionNamesByID(mock.Anything, mock.Anything).Return(
|
||||
[]string{"db1", "db1", "db2"}, []string{"col1", "col2", "col3"}, nil).Once()
|
||||
|
||||
originExtension := Extension
|
||||
defer func() {
|
||||
Extension = originExtension
|
||||
}()
|
||||
hasCheck := false
|
||||
Extension = CheckExtension{
|
||||
reportChecker: func(info any) {
|
||||
infoMap := info.(map[string]any)
|
||||
storage := infoMap[hookutil.StorageDetailKey].(map[string]any)
|
||||
log.Info("storage map", zap.Any("storage", storage))
|
||||
assert.EqualValues(t, 150, storage["db1"])
|
||||
assert.EqualValues(t, 50, storage["db2"])
|
||||
hasCheck = true
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
proxy := &Proxy{
|
||||
ctx: ctx,
|
||||
dataCoord: datacoord,
|
||||
}
|
||||
err := proxy.reportCollectionStorage()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, hasCheck)
|
||||
})
|
||||
}
|
||||
|
||||
type CheckExtension struct {
|
||||
reportChecker func(info any)
|
||||
}
|
||||
|
||||
func (c CheckExtension) Report(info any) int {
|
||||
c.reportChecker(info)
|
||||
return 0
|
||||
}
|
||||
|
@ -117,6 +117,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
@ -280,6 +281,7 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Int("length of pks", pkLength),
|
||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
|
@ -61,7 +61,8 @@ type deleteTask struct {
|
||||
msgID UniqueID
|
||||
|
||||
// result
|
||||
count int64
|
||||
count int64
|
||||
allQueryCnt int64
|
||||
}
|
||||
|
||||
func (dt *deleteTask) TraceCtx() context.Context {
|
||||
@ -246,6 +247,8 @@ type deleteRunner struct {
|
||||
|
||||
// task queue
|
||||
queue *dmTaskQueue
|
||||
|
||||
allQueryCnt atomic.Int64
|
||||
}
|
||||
|
||||
func (dr *deleteRunner) Init(ctx context.Context) error {
|
||||
@ -422,6 +425,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
|
||||
|
||||
taskCh := make(chan *deleteTask, 256)
|
||||
go dr.receiveQueryResult(ctx, client, taskCh)
|
||||
var allQueryCnt int64
|
||||
// wait all task finish
|
||||
for task := range taskCh {
|
||||
err := task.WaitToFinish()
|
||||
@ -429,12 +433,14 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
|
||||
return err
|
||||
}
|
||||
dr.count.Add(task.count)
|
||||
allQueryCnt += task.allQueryCnt
|
||||
}
|
||||
|
||||
// query or produce task failed
|
||||
if dr.err != nil {
|
||||
return dr.err
|
||||
}
|
||||
dr.allQueryCnt.Add(allQueryCnt)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -468,6 +474,7 @@ func (dr *deleteRunner) receiveQueryResult(ctx context.Context, client querypb.Q
|
||||
log.Warn("produce delete task failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
task.allQueryCnt = result.GetAllRetrieveCount()
|
||||
|
||||
taskCh <- task
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ type hybridSearchTask struct {
|
||||
ctx context.Context
|
||||
*internalpb.HybridSearchRequest
|
||||
|
||||
dimension int64
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.HybridSearchRequest
|
||||
searchTasks []*searchTask
|
||||
@ -101,6 +102,11 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Warn("get collection dimension failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
@ -529,6 +535,7 @@ func rankSearchResultData(ctx context.Context,
|
||||
}
|
||||
|
||||
for _, result := range searchResults {
|
||||
ret.Results.AllSearchCount += result.GetResults().GetAllSearchCount()
|
||||
scores := result.GetResults().GetScores()
|
||||
start := int64(0)
|
||||
for i := int64(0); i < nq; i++ {
|
||||
|
@ -54,6 +54,7 @@ type queryTask struct {
|
||||
collectionName string
|
||||
queryParams *queryParams
|
||||
schema *schemaInfo
|
||||
dimension int64
|
||||
|
||||
userOutputFields []string
|
||||
|
||||
@ -65,7 +66,8 @@ type queryTask struct {
|
||||
channelsMvcc map[string]Timestamp
|
||||
fastSkip bool
|
||||
|
||||
reQuery bool
|
||||
reQuery bool
|
||||
allQueryCnt int64
|
||||
}
|
||||
|
||||
type queryParams struct {
|
||||
@ -333,8 +335,17 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
||||
t.queryParams = queryParams
|
||||
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
|
||||
|
||||
schema, _ := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.schema = schema
|
||||
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Warn("get collection dimension failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if t.ids != nil {
|
||||
pkField := ""
|
||||
@ -469,6 +480,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
||||
var err error
|
||||
|
||||
toReduceResults := make([]*internalpb.RetrieveResults, 0)
|
||||
t.allQueryCnt = 0
|
||||
select {
|
||||
case <-t.TraceCtx().Done():
|
||||
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
|
||||
@ -477,6 +489,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
||||
log.Debug("all queries are finished or canceled")
|
||||
t.resultBuf.Range(func(res *internalpb.RetrieveResults) bool {
|
||||
toReduceResults = append(toReduceResults, res)
|
||||
t.allQueryCnt += res.GetAllRetrieveCount()
|
||||
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()))
|
||||
return true
|
||||
})
|
||||
|
@ -472,7 +472,6 @@ func (sched *taskScheduler) processTask(t task, q taskQueue) {
|
||||
|
||||
span.AddEvent("scheduler process PostExecute")
|
||||
err = t.PostExecute(ctx)
|
||||
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
log.Ctx(ctx).Warn("Failed to post-execute task: ", zap.Error(err))
|
||||
|
@ -66,6 +66,7 @@ type searchTask struct {
|
||||
userOutputFields []string
|
||||
|
||||
offset int64
|
||||
dimension int64
|
||||
resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults]
|
||||
|
||||
qc types.QueryCoordClient
|
||||
@ -304,6 +305,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
log.Warn("get collection schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
t.dimension, err = typeutil.GetCollectionDim(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Warn("get collection dimension failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
|
@ -686,15 +686,15 @@ func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
strIds := make([]string, len(data))
|
||||
strIDs := make([]string, len(data))
|
||||
for i, v := range data {
|
||||
strIds[i] = strconv.FormatInt(v, 10)
|
||||
strIDs[i] = strconv.FormatInt(v, 10)
|
||||
}
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strIds,
|
||||
Data: strIDs,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -903,6 +903,11 @@ func GetCurUserFromContext(ctx context.Context) (string, error) {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func GetCurUserFromContextOrDefault(ctx context.Context) string {
|
||||
username, _ := GetCurUserFromContext(ctx)
|
||||
return username
|
||||
}
|
||||
|
||||
func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
@ -1634,3 +1639,16 @@ func CheckDatabase(ctx context.Context, dbName string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func SetReportValue(status *commonpb.Status, value int) {
|
||||
if value <= 0 {
|
||||
return
|
||||
}
|
||||
if !merr.Ok(status) {
|
||||
return
|
||||
}
|
||||
if status.ExtraInfo == nil {
|
||||
status.ExtraInfo = make(map[string]string)
|
||||
}
|
||||
status.ExtraInfo["report_value"] = strconv.Itoa(value)
|
||||
}
|
||||
|
@ -12,26 +12,34 @@ type cntReducer struct{}
|
||||
|
||||
func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
|
||||
cnt := int64(0)
|
||||
allRetrieveCount := int64(0)
|
||||
for _, res := range results {
|
||||
allRetrieveCount += res.GetAllRetrieveCount()
|
||||
c, err := funcutil.CntOfInternalResult(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cnt += c
|
||||
}
|
||||
return funcutil.WrapCntToInternalResult(cnt), nil
|
||||
res := funcutil.WrapCntToInternalResult(cnt)
|
||||
res.AllRetrieveCount = allRetrieveCount
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type cntReducerSegCore struct{}
|
||||
|
||||
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
|
||||
cnt := int64(0)
|
||||
allRetrieveCount := int64(0)
|
||||
for _, res := range results {
|
||||
allRetrieveCount += res.GetAllRetrieveCount()
|
||||
c, err := funcutil.CntOfSegCoreResult(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cnt += c
|
||||
}
|
||||
return funcutil.WrapCntToSegCoreResult(cnt), nil
|
||||
res := funcutil.WrapCntToSegCoreResult(cnt)
|
||||
res.AllRetrieveCount = allRetrieveCount
|
||||
return res, nil
|
||||
}
|
||||
|
@ -130,6 +130,7 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
||||
for j := int64(1); j < nq; j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
}
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
}
|
||||
|
||||
var skipDupCnt int64
|
||||
@ -284,6 +285,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
|
||||
|
||||
validRetrieveResults := []*internalpb.RetrieveResults{}
|
||||
for _, r := range retrieveResults {
|
||||
ret.AllRetrieveCount += r.GetAllRetrieveCount()
|
||||
size := typeutil.GetSizeOfIDs(r.GetIds())
|
||||
if r == nil || len(r.GetFieldsData()) == 0 || size == 0 {
|
||||
continue
|
||||
@ -388,6 +390,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
||||
validRetrieveResults := []*segcorepb.RetrieveResults{}
|
||||
for _, r := range retrieveResults {
|
||||
size := typeutil.GetSizeOfIDs(r.GetIds())
|
||||
ret.AllRetrieveCount += r.GetAllRetrieveCount()
|
||||
if r == nil || len(r.GetOffset()) == 0 || size == 0 {
|
||||
log.Debug("filter out invalid retrieve result")
|
||||
continue
|
||||
|
@ -118,9 +118,10 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
|
||||
|
||||
if len(result.GetOffset()) != 0 {
|
||||
if err = svr.Send(&internalpb.RetrieveResults{
|
||||
Status: merr.Success(),
|
||||
Ids: result.GetIds(),
|
||||
FieldsData: result.GetFieldsData(),
|
||||
Status: merr.Success(),
|
||||
Ids: result.GetIds(),
|
||||
FieldsData: result.GetFieldsData(),
|
||||
AllRetrieveCount: result.GetAllRetrieveCount(),
|
||||
}); err != nil {
|
||||
errs[i] = err
|
||||
}
|
||||
|
@ -626,7 +626,7 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
|
||||
numOfRow := len(rowIDs)
|
||||
cOffset := C.int64_t(offset)
|
||||
cNumOfRows := C.int64_t(numOfRow)
|
||||
cEntityIdsPtr := (*C.int64_t)(&(rowIDs)[0])
|
||||
cEntityIDsPtr := (*C.int64_t)(&(rowIDs)[0])
|
||||
cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0])
|
||||
|
||||
var status C.CStatus
|
||||
@ -635,7 +635,7 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
|
||||
status = C.Insert(s.ptr,
|
||||
cOffset,
|
||||
cNumOfRows,
|
||||
cEntityIdsPtr,
|
||||
cEntityIDsPtr,
|
||||
cTimestampsPtr,
|
||||
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
|
||||
(C.uint64_t)(len(insertRecordBlob)),
|
||||
|
@ -135,6 +135,7 @@ func (t *QueryTask) Execute() error {
|
||||
CostAggregation: &internalpb.CostAggregation{
|
||||
ServiceTime: tr.ElapseSpan().Milliseconds(),
|
||||
},
|
||||
AllRetrieveCount: reducedResult.GetAllRetrieveCount(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -48,10 +48,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int,
|
||||
|
||||
var generalNum int64 = 0
|
||||
collectionsMap := core.meta.ListAllAvailCollections(ctx)
|
||||
for dbId, collectionIds := range collectionsMap {
|
||||
for dbId, collectionIDs := range collectionsMap {
|
||||
db, err := core.meta.GetDatabaseByID(ctx, dbId, ts)
|
||||
if err == nil {
|
||||
for _, collectionId := range collectionIds {
|
||||
for _, collectionId := range collectionIDs {
|
||||
collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true)
|
||||
if err == nil {
|
||||
partNum := int64(collection.GetPartitionNum(false))
|
||||
|
@ -164,7 +164,7 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
|
||||
}
|
||||
|
||||
tss := data.Data[common.TimeStampField].(*Int64FieldData)
|
||||
rowIds := data.Data[common.RowIDField].(*Int64FieldData)
|
||||
rowIDs := data.Data[common.RowIDField].(*Int64FieldData)
|
||||
|
||||
ls := fieldDataList{}
|
||||
for fieldID := range data.Data {
|
||||
@ -176,8 +176,8 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
|
||||
ls.datas = append(ls.datas, data.Data[fieldID])
|
||||
}
|
||||
|
||||
// checkNumRows(tss, rowIds, ls.datas...) // don't work
|
||||
all := []FieldData{tss, rowIds}
|
||||
// checkNumRows(tss, rowIDs, ls.datas...) // don't work
|
||||
all := []FieldData{tss, rowIDs}
|
||||
all = append(all, ls.datas...)
|
||||
if !checkNumRows(all...) {
|
||||
return nil, nil, nil,
|
||||
@ -210,7 +210,7 @@ func TransferColumnBasedInsertDataToRowBased(data *InsertData) (
|
||||
utss[i] = uint64(tss.Data[i])
|
||||
}
|
||||
|
||||
return utss, rowIds.Data, rows, nil
|
||||
return utss, rowIDs.Data, rows, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -125,10 +125,10 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) {
|
||||
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
|
||||
assert.Error(t, err)
|
||||
|
||||
rowIdsF := &Int64FieldData{
|
||||
rowIDsF := &Int64FieldData{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
}
|
||||
data.Data[common.RowIDField] = rowIdsF
|
||||
data.Data[common.RowIDField] = rowIDsF
|
||||
|
||||
// row num mismatch
|
||||
_, _, _, err = TransferColumnBasedInsertDataToRowBased(data)
|
||||
@ -193,10 +193,10 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) {
|
||||
data.Data[111] = f11
|
||||
data.Data[112] = f12
|
||||
|
||||
utss, rowIds, rows, err := TransferColumnBasedInsertDataToRowBased(data)
|
||||
utss, rowIDs, rows, err := TransferColumnBasedInsertDataToRowBased(data)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []uint64{1, 2, 3}, utss)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3}, rowIds)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3}, rowIDs)
|
||||
assert.Equal(t, 3, len(rows))
|
||||
// b := []byte("1")[0]
|
||||
if common.Endian == binary.LittleEndian {
|
||||
|
43
internal/util/hookutil/constant.go
Normal file
43
internal/util/hookutil/constant.go
Normal file
@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package hookutil
|
||||
|
||||
var (
|
||||
// WARN: Please DO NOT modify all constants.
|
||||
|
||||
OpTypeKey = "op_type"
|
||||
DatabaseKey = "database"
|
||||
UsernameKey = "username"
|
||||
DataSizeKey = "data_size"
|
||||
SuccessCntKey = "success_cnt"
|
||||
FailCntKey = "fail_cnt"
|
||||
RelatedCntKey = "related_cnt"
|
||||
StorageDetailKey = "storage_detail"
|
||||
NodeIDKey = "id"
|
||||
DimensionKey = "dim"
|
||||
|
||||
OpTypeInsert = "insert"
|
||||
OpTypeDelete = "delete"
|
||||
OpTypeUpsert = "upsert"
|
||||
OpTypeQuery = "query"
|
||||
OpTypeSearch = "search"
|
||||
OpTypeHybridSearch = "hybrid_search"
|
||||
OpTypeStorage = "storage"
|
||||
OpTypeNodeID = "node_id"
|
||||
)
|
72
internal/util/hookutil/default.go
Normal file
72
internal/util/hookutil/default.go
Normal file
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package hookutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
)
|
||||
|
||||
type DefaultHook struct{}
|
||||
|
||||
var _ hook.Hook = (*DefaultHook)(nil)
|
||||
|
||||
func (d DefaultHook) VerifyAPIKey(key string) (string, error) {
|
||||
return "", errors.New("default hook, can't verify api key")
|
||||
}
|
||||
|
||||
func (d DefaultHook) Init(params map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d DefaultHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (d DefaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (d DefaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
|
||||
type MockAPIHook struct {
|
||||
DefaultHook
|
||||
MockErr error
|
||||
User string
|
||||
}
|
||||
|
||||
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
|
||||
return m.User, m.MockErr
|
||||
}
|
||||
|
||||
func (d DefaultHook) Release() {}
|
||||
|
||||
type DefaultExtension struct{}
|
||||
|
||||
var _ hook.Extension = (*DefaultExtension)(nil)
|
||||
|
||||
func (d DefaultExtension) Report(info any) int {
|
||||
return 0
|
||||
}
|
102
internal/util/hookutil/hook.go
Normal file
102
internal/util/hookutil/hook.go
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package hookutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"plugin"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
var (
|
||||
Hoo hook.Hook
|
||||
Extension hook.Extension
|
||||
initOnce sync.Once
|
||||
)
|
||||
|
||||
func initHook() error {
|
||||
Hoo = DefaultHook{}
|
||||
Extension = DefaultExtension{}
|
||||
|
||||
path := paramtable.Get().ProxyCfg.SoPath.GetValue()
|
||||
if path == "" {
|
||||
log.Info("empty so path, skip to load plugin")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("start to load plugin", zap.String("path", path))
|
||||
p, err := plugin.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to open the plugin, error: %s", err.Error())
|
||||
}
|
||||
log.Info("plugin open")
|
||||
|
||||
h, err := p.Lookup("MilvusHook")
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
|
||||
}
|
||||
|
||||
var ok bool
|
||||
Hoo, ok = h.(hook.Hook)
|
||||
if !ok {
|
||||
return fmt.Errorf("fail to convert the `Hook` interface")
|
||||
}
|
||||
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
|
||||
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
|
||||
}
|
||||
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
|
||||
log.Info("receive the hook refresh event", zap.Any("event", event))
|
||||
go func() {
|
||||
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
|
||||
log.Info("refresh hook configs", zap.Any("config", soConfig))
|
||||
if err = Hoo.Init(soConfig); err != nil {
|
||||
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
e, err := p.Lookup("MilvusExtension")
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
|
||||
}
|
||||
Extension, ok = e.(hook.Extension)
|
||||
if !ok {
|
||||
return fmt.Errorf("fail to convert the `Extension` interface")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InitOnceHook() {
|
||||
initOnce.Do(func() {
|
||||
err := initHook()
|
||||
if err != nil {
|
||||
log.Warn("fail to init hook",
|
||||
zap.String("so_path", paramtable.Get().ProxyCfg.SoPath.GetValue()),
|
||||
zap.Error(err))
|
||||
}
|
||||
})
|
||||
}
|
52
internal/util/hookutil/hook_test.go
Normal file
52
internal/util/hookutil/hook_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package hookutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestInitHook(t *testing.T) {
|
||||
paramtable.Init()
|
||||
Params := paramtable.Get()
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
|
||||
initHook()
|
||||
assert.IsType(t, DefaultHook{}, Hoo)
|
||||
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
|
||||
err := initHook()
|
||||
assert.Error(t, err)
|
||||
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
|
||||
}
|
||||
|
||||
func TestDefaultHook(t *testing.T) {
|
||||
d := &DefaultHook{}
|
||||
assert.NoError(t, d.Init(nil))
|
||||
{
|
||||
_, err := d.VerifyAPIKey("key")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
d.Release()
|
||||
})
|
||||
}
|
@ -13,7 +13,7 @@ func appendFieldData(result RetrieveResults, fieldData *schemapb.FieldData) {
|
||||
result.AppendFieldData(fieldData)
|
||||
}
|
||||
|
||||
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, schema *schemapb.CollectionSchema) error {
|
||||
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIDs []int64, schema *schemapb.CollectionSchema) error {
|
||||
if !result.ResultEmpty() {
|
||||
return nil
|
||||
}
|
||||
@ -24,7 +24,7 @@ func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, s
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, outputFieldID := range outputFieldIds {
|
||||
for _, outputFieldID := range outputFieldIDs {
|
||||
field, err := helper.GetFieldFromID(outputFieldID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -14,7 +14,7 @@ require (
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/klauspost/compress v1.16.5
|
||||
github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f
|
||||
github.com/nats-io/nats-server/v2 v2.9.17
|
||||
github.com/nats-io/nats.go v1.24.0
|
||||
github.com/panjf2000/ants/v2 v2.7.2
|
||||
|
@ -483,8 +483,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
|
||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46 h1:IgoGNTbsRPa2kdNI+IWuZrrortFEjTB42/gYDklZHVU=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f h1:f8rRJ5zatNq2WszAwy6S+J0Z2h7/CArqLDJ0gTHSFNs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240313071055-2c89f346b00f/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
|
||||
github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A=
|
||||
github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w=
|
||||
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
|
||||
|
@ -27,6 +27,9 @@ const (
|
||||
|
||||
// SystemInfoMetrics means users request for system information metrics.
|
||||
SystemInfoMetrics = "system_info"
|
||||
|
||||
// CollectionStorageMetrics means users request for collection storage metrics.
|
||||
CollectionStorageMetrics = "collection_storage"
|
||||
)
|
||||
|
||||
// ParseMetricType returns the metric type of req
|
||||
|
@ -27,3 +27,14 @@ func GetDim(field *schemapb.FieldSchema) (int64, error) {
|
||||
}
|
||||
return int64(dim), nil
|
||||
}
|
||||
|
||||
func GetCollectionDim(collection *schemapb.CollectionSchema) (int64, error) {
|
||||
for _, fieldSchema := range collection.GetFields() {
|
||||
dim, err := GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return dim, nil
|
||||
}
|
||||
return 0, fmt.Errorf("dim not found")
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@ -28,6 +29,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
@ -77,6 +79,27 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
|
||||
|
||||
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
insertCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("insert check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("insert report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeInsert, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go insertCheckReport()
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
@ -145,11 +168,99 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
searchCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("search check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("search report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go searchCheckReport()
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
err = merr.CheckRPCCall(searchResult, err)
|
||||
s.NoError(err)
|
||||
|
||||
queryCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("query check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("query report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go queryCheckReport()
|
||||
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: "",
|
||||
OutputFields: []string{"count(*)"},
|
||||
})
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
|
||||
|
||||
deleteCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("delete check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
|
||||
s.EqualValues(2, reportInfo[hookutil.SuccessCntKey])
|
||||
s.EqualValues(0, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go deleteCheckReport()
|
||||
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: integration.Int64Field + " in [1, 2]",
|
||||
})
|
||||
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
|
||||
|
||||
status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
|
@ -47,6 +47,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
@ -125,6 +126,8 @@ type MiniClusterV2 struct {
|
||||
qnid atomic.Int64
|
||||
datanodes []*grpcdatanode.Server
|
||||
dnid atomic.Int64
|
||||
|
||||
Extension *ReportChanExtension
|
||||
}
|
||||
|
||||
type OptionV2 func(cluster *MiniClusterV2)
|
||||
@ -136,6 +139,8 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2,
|
||||
dnid: *atomic.NewInt64(20000),
|
||||
}
|
||||
paramtable.Init()
|
||||
cluster.Extension = InitReportExtension()
|
||||
|
||||
cluster.params = DefaultParams()
|
||||
cluster.clusterConfig = DefaultClusterConfig()
|
||||
for _, opt := range opts {
|
||||
@ -445,3 +450,32 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
|
||||
defer listener.Close()
|
||||
return listener.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
func InitReportExtension() *ReportChanExtension {
|
||||
e := NewReportChanExtension()
|
||||
hookutil.InitOnceHook()
|
||||
hookutil.Extension = e
|
||||
return e
|
||||
}
|
||||
|
||||
type ReportChanExtension struct {
|
||||
reportChan chan any
|
||||
}
|
||||
|
||||
func NewReportChanExtension() *ReportChanExtension {
|
||||
return &ReportChanExtension{
|
||||
reportChan: make(chan any),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReportChanExtension) Report(info any) int {
|
||||
select {
|
||||
case r.reportChan <- info:
|
||||
default:
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *ReportChanExtension) GetReportChan() <-chan any {
|
||||
return r.reportChan
|
||||
}
|
||||
|
411
tests/integration/partitionkey/partition_key_test.go
Normal file
411
tests/integration/partitionkey/partition_key_test.go
Normal file
@ -0,0 +1,411 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package partitionkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
type PartitionKeySuite struct {
|
||||
integration.MiniClusterSuite
|
||||
}
|
||||
|
||||
func (s *PartitionKeySuite) TestPartitionKey() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
c := s.Cluster
|
||||
|
||||
const (
|
||||
dim = 128
|
||||
dbName = ""
|
||||
rowNum = 1000
|
||||
)
|
||||
|
||||
collectionName := "TestPartitionKey" + funcutil.GenRandomStr()
|
||||
schema := integration.ConstructSchema(collectionName, dim, false)
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: 102,
|
||||
Name: "pid",
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
IsPartitionKey: true,
|
||||
})
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
})
|
||||
s.NoError(err)
|
||||
if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
|
||||
}
|
||||
s.Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
{
|
||||
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, 0)
|
||||
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
|
||||
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 1)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
}
|
||||
|
||||
{
|
||||
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, rowNum)
|
||||
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
|
||||
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 2)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
}
|
||||
|
||||
{
|
||||
pkColumn := integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, rowNum*2)
|
||||
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
|
||||
partitionKeyColumn := integration.NewInt64SameFieldData("pid", rowNum, 3)
|
||||
hashKeys := integration.GenerateHashKeys(rowNum)
|
||||
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn, partitionKeyColumn},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
}
|
||||
|
||||
flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
|
||||
DbName: dbName,
|
||||
CollectionNames: []string{collectionName},
|
||||
})
|
||||
s.NoError(err)
|
||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||
ids := segmentIDs.GetData()
|
||||
s.Require().NotEmpty(segmentIDs)
|
||||
s.Require().True(has)
|
||||
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
|
||||
s.True(has)
|
||||
|
||||
segments, err := c.MetaWatcher.ShowSegments()
|
||||
s.NoError(err)
|
||||
s.NotEmpty(segments)
|
||||
for _, segment := range segments {
|
||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||
}
|
||||
s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
|
||||
|
||||
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
|
||||
|
||||
// load
|
||||
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
s.NoError(err)
|
||||
if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
||||
}
|
||||
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||
s.WaitForLoad(ctx, collectionName)
|
||||
|
||||
{
|
||||
// search without partition key
|
||||
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("search check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("search report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(rowNum*3, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go searchCheckReport()
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// search without partition key
|
||||
expr := fmt.Sprintf("%s > 0 && pid == 1", integration.Int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("search check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("search report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go searchCheckReport()
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// query without partition key
|
||||
queryCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("query check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("query report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(3*rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go queryCheckReport()
|
||||
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: "",
|
||||
OutputFields: []string{"count(*)"},
|
||||
})
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// query with partition key
|
||||
queryCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("query check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("query report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.DataSizeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go queryCheckReport()
|
||||
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: "pid == 1",
|
||||
OutputFields: []string{"count(*)"},
|
||||
})
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// delete without partition key
|
||||
deleteCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("delete check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.SuccessCntKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go deleteCheckReport()
|
||||
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: integration.Int64Field + " < 1000",
|
||||
})
|
||||
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
{
|
||||
// delete with partition key
|
||||
deleteCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("delete check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
log.Info("delete report info", zap.Any("reportInfo", reportInfo))
|
||||
if reportInfo[hookutil.OpTypeKey] == hookutil.OpTypeStorage {
|
||||
continue
|
||||
}
|
||||
s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.SuccessCntKey])
|
||||
s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go deleteCheckReport()
|
||||
deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Expr: integration.Int64Field + " < 2000 && pid == 2",
|
||||
})
|
||||
if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionKey(t *testing.T) {
|
||||
suite.Run(t, new(PartitionKeySuite))
|
||||
}
|
@ -28,6 +28,7 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
@ -104,6 +105,24 @@ func (s *MiniClusterSuite) SetupTest() {
|
||||
s.Cluster = c
|
||||
|
||||
// start mini cluster
|
||||
nodeIDCheckReport := func() {
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
s.Fail("node id check timeout")
|
||||
case report := <-c.Extension.GetReportChan():
|
||||
reportInfo := report.(map[string]any)
|
||||
s.T().Log("node id report info: ", reportInfo)
|
||||
s.Equal(hookutil.OpTypeNodeID, reportInfo[hookutil.OpTypeKey])
|
||||
s.NotEqualValues(0, reportInfo[hookutil.NodeIDKey])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go nodeIDCheckReport()
|
||||
s.Require().NoError(s.Cluster.Start())
|
||||
}
|
||||
|
||||
|
@ -58,7 +58,39 @@ func NewInt64FieldData(fieldName string, numRows int) *schemapb.FieldData {
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: GenerateInt64Array(numRows),
|
||||
Data: GenerateInt64Array(numRows, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewInt64FieldDataWithStart(fieldName string, numRows int, start int64) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: GenerateInt64Array(numRows, start),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewInt64SameFieldData(fieldName string, numRows int, value int64) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: GenerateSameInt64Array(numRows, value),
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -144,10 +176,18 @@ func NewBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.Fiel
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateInt64Array(numRows int) []int64 {
|
||||
func GenerateInt64Array(numRows int, start int64) []int64 {
|
||||
ret := make([]int64, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret[i] = int64(i)
|
||||
ret[i] = int64(i) + start
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func GenerateSameInt64Array(numRows int, value int64) []int64 {
|
||||
ret := make([]int64, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret[i] = value
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user