// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mocks import ( "context" "net" "sync" "testing" "time" "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) type MockQueryNode struct { *MockQueryNodeServer ID int64 addr string ctx context.Context cancel context.CancelFunc session *sessionutil.Session server *grpc.Server rwmutex sync.RWMutex channels map[int64][]string channelVersion map[string]int64 segments map[int64]map[string][]int64 segmentVersion map[int64]int64 } func NewMockQueryNode(t *testing.T, etcdCli *clientv3.Client, nodeID int64) *MockQueryNode { ctx, cancel := context.WithCancel(context.Background()) node := &MockQueryNode{ MockQueryNodeServer: NewMockQueryNodeServer(t), ctx: ctx, cancel: cancel, session: sessionutil.NewSessionWithEtcd(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli), channels: make(map[int64][]string), segments: make(map[int64]map[string][]int64), ID: nodeID, } return node } func (node *MockQueryNode) Start() error { // Start gRPC server lis, err := net.Listen("tcp", "localhost:0") if err != nil { return err } node.addr = lis.Addr().String() node.server = grpc.NewServer() querypb.RegisterQueryNodeServer(node.server, node) go func() { err = node.server.Serve(lis) }() successStatus := merr.Success() node.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{ Status: successStatus, NodeID: node.ID, Channels: node.getAllChannels(), Segments: node.getAllSegments(), }, nil).Maybe() node.EXPECT().WatchDmChannels(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *querypb.WatchDmChannelsRequest) { node.rwmutex.Lock() defer node.rwmutex.Unlock() node.channels[req.GetCollectionID()] = append(node.channels[req.GetCollectionID()], req.GetInfos()[0].GetChannelName()) }).Return(successStatus, nil).Maybe() node.EXPECT().LoadSegments(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *querypb.LoadSegmentsRequest) { node.rwmutex.Lock() defer node.rwmutex.Unlock() shardSegments, ok := node.segments[req.GetCollectionID()] if !ok { shardSegments = make(map[string][]int64) node.segments[req.GetCollectionID()] = shardSegments } segment := req.GetInfos()[0] shardSegments[segment.GetInsertChannel()] = append(shardSegments[segment.GetInsertChannel()], segment.GetSegmentID()) node.segmentVersion[segment.GetSegmentID()] = req.GetVersion() }).Return(successStatus, nil).Maybe() node.EXPECT().GetComponentStates(mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")). Call.Return(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates { select { case <-node.ctx.Done(): return nil default: return &milvuspb.ComponentStates{ Status: successStatus, } } }, func(context.Context, *milvuspb.GetComponentStatesRequest) error { select { case <-node.ctx.Done(): return grpc.ErrServerStopped default: return nil } }).Maybe() // Register node.session.Init(typeutil.QueryNodeRole, node.addr, false, true) node.session.ServerID = node.ID node.session.Register() log.Debug("mock QueryNode started", zap.Int64("nodeID", node.ID), zap.String("nodeAddr", node.addr)) return err } func (node *MockQueryNode) Stopping() { node.session.GoingStop() } func (node *MockQueryNode) Stop() { node.cancel() node.server.GracefulStop() node.session.Revoke(time.Second) } func (node *MockQueryNode) getAllChannels() []*querypb.ChannelVersionInfo { node.rwmutex.RLock() defer node.rwmutex.RUnlock() ret := make([]*querypb.ChannelVersionInfo, 0) for collection, channels := range node.channels { for _, channel := range channels { ret = append(ret, &querypb.ChannelVersionInfo{ Channel: channel, Collection: collection, Version: node.channelVersion[channel], }) } } return ret } func (node *MockQueryNode) getAllSegments() []*querypb.SegmentVersionInfo { node.rwmutex.RLock() defer node.rwmutex.RUnlock() ret := make([]*querypb.SegmentVersionInfo, 0) for collection, shardSegments := range node.segments { for shard, segments := range shardSegments { for _, segment := range segments { ret = append(ret, &querypb.SegmentVersionInfo{ ID: segment, Collection: collection, Channel: shard, Version: node.segmentVersion[segment], }) } } } return ret }