Mirror of https://gitee.com/milvus-io/milvus.git (synced 2024-11-30 02:48:45 +08:00)
Refactor datanode register

DataService should not be stalled while registering a DataNode.

Signed-off-by: sunby <bingyi.sun@zilliz.com>
Parent: deba964590
Commit: a7dac818ee
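In outline, the old flow blocked DataService startup on a channel that was only closed once Params.DataNodeNum data nodes had registered; the new flow registers each node as its RegisterNode request arrives and reports failures in the response, so the service never stalls waiting for stragglers. A minimal sketch of the two patterns, with invented names rather than the actual Milvus types:

package sketch

import (
    "errors"
    "sync"
)

// Old pattern: startup waits until every expected node has registered.
type blockingRegistry struct {
    mu       sync.Mutex
    nodes    []int64
    expected int
    done     chan struct{} // closed when len(nodes) == expected
}

func (r *blockingRegistry) Register(id int64) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.nodes = append(r.nodes, id)
    if len(r.nodes) == r.expected {
        close(r.done) // unblocks whoever waits during startup
    }
}

// New pattern: no global wait; each registration succeeds or fails on its own.
type registry struct {
    mu    sync.Mutex
    nodes map[int64]struct{}
}

func (r *registry) Register(id int64) error {
    r.mu.Lock()
    defer r.mu.Unlock()
    if _, ok := r.nodes[id]; ok {
        return errors.New("datanode already exists")
    }
    r.nodes[id] = struct{}{}
    return nil
}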
@@ -55,7 +55,7 @@ func NewDataNode(ctx context.Context, factory msgstream.Factory) *DataNode {
        ctx: ctx2,
        cancel: cancel2,
        Role: typeutil.DataNodeRole,
        watchDm: make(chan struct{}),
        watchDm: make(chan struct{}, 1),

        dataSyncService: nil,
        metaService: nil,
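The watchDm channel gains a capacity of 1, so the goroutine that signals a watch event can hand the notification off without blocking when the receiver is not ready yet; the diff only shows the capacity change, so the intent is inferred here. A small, self-contained illustration of the difference:

package main

import "fmt"

func main() {
    // Unbuffered: a send blocks until someone is receiving.
    // watchDm := make(chan struct{})

    // Capacity 1: one pending signal can be parked without blocking the sender.
    watchDm := make(chan struct{}, 1)

    select {
    case watchDm <- struct{}{}:
        fmt.Println("signal delivered without blocking")
    default:
        fmt.Println("a signal was already pending; nothing to do")
    }

    <-watchDm // the receiver picks the signal up later
}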
@@ -106,6 +106,9 @@ func (node *DataNode) Init() error {
    if err != nil {
        return fmt.Errorf("Register node failed: %v", err)
    }
    if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
        return fmt.Errorf("Receive error when registering data node, msg: %s", resp.Status.Reason)
    }

    select {
    case <-time.After(RPCConnectionTimeout):
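Registration can now fail in two ways: the RPC itself returns an error, or the call succeeds at the transport level but the response carries a non-success status. The added check turns the second case into a normal Go error as well. A generic sketch of that pattern (the Status type below is a stand-in, not the milvus-distributed commonpb type):

package main

import (
    "errors"
    "fmt"
)

// Stand-in for a protobuf-style status carried inside an RPC response.
type Status struct {
    ErrorCode int // 0 means success
    Reason    string
}

// statusToErr folds an RPC error or a non-success response status into one error value.
func statusToErr(status *Status, rpcErr error) error {
    if rpcErr != nil {
        return fmt.Errorf("register node failed: %w", rpcErr)
    }
    if status == nil {
        return errors.New("register node failed: empty status")
    }
    if status.ErrorCode != 0 {
        return fmt.Errorf("register node rejected: %s", status.Reason)
    }
    return nil
}

func main() {
    fmt.Println(statusToErr(&Status{ErrorCode: 1, Reason: "duplicate node"}, nil))
}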
@@ -2,6 +2,7 @@ package dataservice

import (
    "context"
    "errors"
    "fmt"
    "sync"

@@ -26,30 +27,27 @@ type dataNode struct {
}
type dataNodeCluster struct {
    sync.RWMutex
    finishCh chan struct{}
    nodes []*dataNode
    nodes []*dataNode
}

func (node *dataNode) String() string {
    return fmt.Sprintf("id: %d, address: %s:%d", node.id, node.address.ip, node.address.port)
}

func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster {
func newDataNodeCluster() *dataNodeCluster {
    return &dataNodeCluster{
        finishCh: finishCh,
        nodes: make([]*dataNode, 0),
        nodes: make([]*dataNode, 0),
    }
}

func (c *dataNodeCluster) Register(dataNode *dataNode) {
func (c *dataNodeCluster) Register(dataNode *dataNode) error {
    c.Lock()
    defer c.Unlock()
    if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
        c.nodes = append(c.nodes, dataNode)
        if len(c.nodes) == Params.DataNodeNum {
            close(c.finishCh)
        }
        return nil
    }
    return errors.New("datanode already exist")
}

func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
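For context on what was removed: closing finishCh acted as a one-shot broadcast that released waitDataNodeRegister once the expected number of nodes had registered. A minimal, self-contained illustration of that close-as-broadcast idiom:

package main

import (
    "fmt"
    "sync"
)

func main() {
    done := make(chan struct{})
    var wg sync.WaitGroup

    // Several goroutines block until the channel is closed; close() releases
    // all of them at once, which is what the removed finishCh relied on.
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            <-done
            fmt.Println("waiter", id, "released")
        }(i)
    }

    close(done) // one-shot broadcast
    wg.Wait()
}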
@@ -151,6 +149,5 @@ func (c *dataNodeCluster) ShutDownClients() {
func (c *dataNodeCluster) Clear() {
    c.Lock()
    defer c.Unlock()
    c.finishCh = make(chan struct{})
    c.nodes = make([]*dataNode, 0)
}

@@ -11,8 +11,7 @@ import (
func TestDataNodeClusterRegister(t *testing.T) {
    Params.Init()
    Params.DataNodeNum = 3
    ch := make(chan struct{})
    cluster := newDataNodeCluster(ch)
    cluster := newDataNodeCluster()
    ids := make([]int64, 0, Params.DataNodeNum)
    for i := 0; i < Params.DataNodeNum; i++ {
        c := newMockDataNodeClient(int64(i))
@@ -31,8 +30,6 @@ func TestDataNodeClusterRegister(t *testing.T) {
        })
        ids = append(ids, int64(i))
    }
    _, ok := <-ch
    assert.False(t, ok)
    assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes())
    assert.EqualValues(t, ids, cluster.GetNodeIDs())
    states, err := cluster.GetDataNodeStates(context.TODO())
@@ -64,7 +61,7 @@ func TestWatchChannels(t *testing.T) {
        {1, []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7"}, []int{3, 2, 2}},
    }

    cluster := newDataNodeCluster(make(chan struct{}))
    cluster := newDataNodeCluster()
    for _, c := range cases {
        for i := 0; i < Params.DataNodeNum; i++ {
            c := newMockDataNodeClient(int64(i))

@@ -11,6 +11,7 @@ import (
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/datapb"
    "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
    "github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
@@ -98,3 +99,126 @@ func (c *mockDataNodeClient) Stop() error {
    c.state = internalpb.StateCode_Abnormal
    return nil
}

type mockMasterService struct {
}

func newMockMasterService() *mockMasterService {
    return &mockMasterService{}
}

func (m *mockMasterService) Init() error {
    return nil
}

func (m *mockMasterService) Start() error {
    return nil
}

func (m *mockMasterService) Stop() error {
    return nil
}

func (m *mockMasterService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
    return &internalpb.ComponentStates{
        State: &internalpb.ComponentInfo{
            NodeID: 0,
            Role: "",
            StateCode: internalpb.StateCode_Healthy,
            ExtraInfo: []*commonpb.KeyValuePair{},
        },
        SubcomponentStates: []*internalpb.ComponentInfo{},
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason: "",
        },
    }, nil
}

func (m *mockMasterService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    panic("not implemented") // TODO: Implement
}

//DDL request
func (m *mockMasterService) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) {
    return &milvuspb.ShowCollectionsResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason: "",
        },
        CollectionNames: []string{},
    }, nil
}

func (m *mockMasterService) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
    panic("not implemented") // TODO: Implement
}

//index builder service
func (m *mockMasterService) CreateIndex(ctx context.Context, req *milvuspb.CreateIndexRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) DescribeIndex(ctx context.Context, req *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) DropIndex(ctx context.Context, req *milvuspb.DropIndexRequest) (*commonpb.Status, error) {
    panic("not implemented") // TODO: Implement
}

//global timestamp allocator
func (m *mockMasterService) AllocTimestamp(ctx context.Context, req *masterpb.AllocTimestampRequest) (*masterpb.AllocTimestampResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) AllocID(ctx context.Context, req *masterpb.AllocIDRequest) (*masterpb.AllocIDResponse, error) {
    panic("not implemented") // TODO: Implement
}

//segment
func (m *mockMasterService) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) {
    panic("not implemented") // TODO: Implement
}

func (m *mockMasterService) GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return &milvuspb.StringResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
            Reason: "",
        },
        Value: "ddchannel",
    }, nil
}

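The mock implements only the calls the registration path actually exercises (component state, collection listing, and the DD channel name) and panics on everything else, so an unexpected dependency shows up immediately in tests. A compact, self-contained version of the same idea, using a much smaller stand-in interface:

package main

import "fmt"

// Stand-in for the interface a test double has to satisfy.
type master interface {
    GetDdChannel() (string, error)
    AllocID() (int64, error)
}

type mockMaster struct{}

// Only the method the code under test uses gets a real implementation.
func (m mockMaster) GetDdChannel() (string, error) { return "ddchannel", nil }

// Everything else fails loudly if it is ever called.
func (m mockMaster) AllocID() (int64, error) { panic("not implemented") }

func main() {
    var ms master = mockMaster{}
    ch, _ := ms.GetDdChannel()
    fmt.Println("dd channel:", ch)
}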
@@ -39,40 +39,44 @@ type (
    Timestamp = typeutil.Timestamp
)
type Server struct {
    ctx context.Context
    serverLoopCtx context.Context
    serverLoopCancel context.CancelFunc
    serverLoopWg sync.WaitGroup
    state atomic.Value
    client *etcdkv.EtcdKV
    meta *meta
    segAllocator segmentAllocatorInterface
    statsHandler *statsHandler
    ddHandler *ddHandler
    allocator allocatorInterface
    cluster *dataNodeCluster
    msgProducer *timesync.MsgProducer
    registerFinishCh chan struct{}
    masterClient types.MasterService
    ttMsgStream msgstream.MsgStream
    k2sMsgStream msgstream.MsgStream
    ddChannelName string
    segmentInfoStream msgstream.MsgStream
    insertChannels []string
    msFactory msgstream.Factory
    ttBarrier timesync.TimeTickBarrier
    ctx context.Context
    serverLoopCtx context.Context
    serverLoopCancel context.CancelFunc
    serverLoopWg sync.WaitGroup
    state atomic.Value
    client *etcdkv.EtcdKV
    meta *meta
    segAllocator segmentAllocatorInterface
    statsHandler *statsHandler
    ddHandler *ddHandler
    allocator allocatorInterface
    cluster *dataNodeCluster
    msgProducer *timesync.MsgProducer
    masterClient types.MasterService
    ttMsgStream msgstream.MsgStream
    k2sMsgStream msgstream.MsgStream
    ddChannelMu struct {
        sync.Mutex
        name string
    }
    segmentInfoStream msgstream.MsgStream
    insertChannels []string
    msFactory msgstream.Factory
    ttBarrier timesync.TimeTickBarrier
    createDataNodeClient func(addr string) types.DataNode
}

func CreateServer(ctx context.Context, factory msgstream.Factory) (*Server, error) {
    rand.Seed(time.Now().UnixNano())
    ch := make(chan struct{})
    s := &Server{
        ctx: ctx,
        registerFinishCh: ch,
        cluster: newDataNodeCluster(ch),
        msFactory: factory,
        ctx: ctx,
        cluster: newDataNodeCluster(),
        msFactory: factory,
    }
    s.insertChannels = s.getInsertChannels()
    s.createDataNodeClient = func(addr string) types.DataNode {
        return grpcdatanodeclient.NewClient(addr)
    }
    s.UpdateStateCode(internalpb.StateCode_Abnormal)
    return s, nil
}
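Two things about the new Server fields are worth spelling out: ddChannelName becomes ddChannelMu, a small struct embedding sync.Mutex next to the value it protects, and createDataNodeClient is a constructor hook so tests can swap the real gRPC client for a mock (server_test.go below does exactly that). A self-contained sketch of the factory-field idea, with invented names:

package main

import "fmt"

// The dependency the server talks to.
type dataNodeClient interface {
    Init() error
}

type realClient struct{ addr string }

func (c *realClient) Init() error { return nil }

type server struct {
    // Production default; a test overwrites this field with a mock factory.
    newClient func(addr string) dataNodeClient
}

func newServer() *server {
    return &server{
        newClient: func(addr string) dataNodeClient { return &realClient{addr: addr} },
    }
}

type fakeClient struct{}

func (fakeClient) Init() error { return nil }

func main() {
    s := newServer()
    s.newClient = func(addr string) dataNodeClient { return fakeClient{} } // test override
    fmt.Printf("%T\n", s.newClient("localhost:1000"))
}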
@@ -105,10 +109,12 @@ func (s *Server) Start() error {
        return err
    }

    s.allocator = newAllocator(s.masterClient)
    if err = s.initMeta(); err != nil {
        return err
    }

    s.allocator = newAllocator(s.masterClient)

    s.statsHandler = newStatsHandler(s.meta)
    s.ddHandler = newDDHandler(s.meta, s.segAllocator, s.masterClient)
    s.initSegmentInfoChannel()
@@ -116,8 +122,6 @@ func (s *Server) Start() error {
    if err = s.loadMetaFromMaster(); err != nil {
        return err
    }
    s.waitDataNodeRegister()
    s.cluster.WatchInsertChannels(s.insertChannels)
    if err = s.initMsgProducer(); err != nil {
        return err
    }
@@ -149,11 +153,7 @@ func (s *Server) initMeta() error {
        }
        return nil
    }
    err := retry.Retry(100000, time.Millisecond*200, connectEtcdFn)
    if err != nil {
        return err
    }
    return nil
    return retry.Retry(100000, time.Millisecond*200, connectEtcdFn)
}

func (s *Server) initSegmentInfoChannel() {
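The err-check wrapper around retry.Retry added nothing, since initMeta's only remaining job is to report whether the retried etcd connect eventually succeeded, so the call's result is now returned directly. For reference, the helper follows the usual attempts/interval/closure shape; a self-contained toy version of such a helper (my own sketch, not the milvus-distributed retry package):

package main

import (
    "errors"
    "fmt"
    "time"
)

// retryFn calls fn up to attempts times, sleeping interval between failures.
func retryFn(attempts int, interval time.Duration, fn func() error) error {
    var err error
    for i := 0; i < attempts; i++ {
        if err = fn(); err == nil {
            return nil
        }
        time.Sleep(interval)
    }
    return err
}

func main() {
    calls := 0
    connect := func() error {
        calls++
        if calls < 3 {
            return errors.New("etcd not ready")
        }
        return nil
    }
    // Returning the helper's error directly is all the caller needs to do.
    fmt.Println(retryFn(5, time.Millisecond, connect), "after", calls, "calls")
}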
@@ -163,6 +163,7 @@ func (s *Server) initSegmentInfoChannel() {
    s.segmentInfoStream = segmentInfoStream
    s.segmentInfoStream.Start()
}

func (s *Server) initMsgProducer() error {
    var err error
    if s.ttMsgStream, err = s.msFactory.NewMsgStream(s.ctx); err != nil {
@@ -195,12 +196,8 @@ func (s *Server) loadMetaFromMaster() error {
    if err = s.checkMasterIsHealthy(); err != nil {
        return err
    }
    if s.ddChannelName == "" {
        channel, err := s.masterClient.GetDdChannel(ctx)
        if err != nil {
            return err
        }
        s.ddChannelName = channel.Value
    if err = s.getDDChannel(); err != nil {
        return err
    }
    collections, err := s.masterClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
        Base: &commonpb.MsgBase{
@@ -258,6 +255,19 @@ func (s *Server) loadMetaFromMaster() error {
    return nil
}

func (s *Server) getDDChannel() error {
    s.ddChannelMu.Lock()
    defer s.ddChannelMu.Unlock()
    if len(s.ddChannelMu.name) == 0 {
        resp, err := s.masterClient.GetDdChannel(s.ctx)
        if err = VerifyResponse(resp, err); err != nil {
            return err
        }
        s.ddChannelMu.name = resp.Value
    }
    return nil
}

func (s *Server) checkMasterIsHealthy() error {
    ticker := time.NewTicker(300 * time.Millisecond)
    ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second)
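getDDChannel is what lets both loadMetaFromMaster and RegisterNode fetch the DD channel lazily: whichever path gets there first asks the master once, and later callers reuse the cached name under the mutex. The same lazy-init-under-a-lock shape, reduced to a self-contained example with invented names:

package main

import (
    "fmt"
    "sync"
)

type ddChannelCache struct {
    mu   sync.Mutex
    name string
}

// get returns the cached value, calling fetch at most once under the lock.
func (c *ddChannelCache) get(fetch func() (string, error)) (string, error) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.name == "" {
        name, err := fetch()
        if err != nil {
            return "", err
        }
        c.name = name
    }
    return c.name, nil
}

func main() {
    calls := 0
    cache := &ddChannelCache{}
    fetch := func() (string, error) { calls++; return "ddchannel", nil }

    for i := 0; i < 3; i++ {
        name, _ := cache.get(fetch)
        fmt.Println(name)
    }
    fmt.Println("fetched", calls, "time(s)") // 1
}

sync.Once would also work, but a plain mutex keeps the error handling simple when the fetch itself can fail and may need to be retried later.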
@@ -364,9 +374,13 @@ func (s *Server) startProxyServiceTimeTickLoop(ctx context.Context) {
        select {
        case <-ctx.Done():
            log.Debug("Proxy service timetick loop shut down")
            return
        default:
        }
        msgPack := flushStream.Consume()
        if msgPack == nil {
            continue
        }
        for _, msg := range msgPack.Msgs {
            if msg.Type() != commonpb.MsgType_TimeTick {
                log.Warn("receive unknown msg from proxy service timetick", zap.Stringer("msgType", msg.Type()))
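Consume can hand back a nil pack, for example while the stream is shutting down, and ranging over msgPack.Msgs would then panic, so the loop now skips nil results. A self-contained sketch of that guard (the consumer here is a stand-in, not msgstream.MsgStream):

package main

import "fmt"

type msgPack struct{ Msgs []string }

// A stand-in consumer that occasionally returns nil, like a closing stream.
func consume(i int) *msgPack {
    if i%2 == 0 {
        return nil
    }
    return &msgPack{Msgs: []string{fmt.Sprintf("tick-%d", i)}}
}

func main() {
    for i := 0; i < 4; i++ {
        pack := consume(i)
        if pack == nil {
            continue // without this guard, pack.Msgs below would panic
        }
        for _, m := range pack.Msgs {
            fmt.Println(m)
        }
    }
}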
@@ -384,8 +398,8 @@ func (s *Server) startProxyServiceTimeTickLoop(ctx context.Context) {
func (s *Server) startDDChannel(ctx context.Context) {
    defer s.serverLoopWg.Done()
    ddStream, _ := s.msFactory.NewMsgStream(ctx)
    ddStream.AsConsumer([]string{s.ddChannelName}, Params.DataServiceSubscriptionName)
    log.Debug("dataservice AsConsumer: " + s.ddChannelName + " : " + Params.DataServiceSubscriptionName)
    ddStream.AsConsumer([]string{s.ddChannelMu.name}, Params.DataServiceSubscriptionName)
    log.Debug("dataservice AsConsumer: " + s.ddChannelMu.name + " : " + Params.DataServiceSubscriptionName)
    ddStream.Start()
    defer ddStream.Close()
    for {
@@ -405,17 +419,11 @@ func (s *Server) startDDChannel(ctx context.Context) {
    }
}

func (s *Server) waitDataNodeRegister() {
    log.Debug("waiting data node to register")
    <-s.registerFinishCh
    log.Debug("all data nodes register")
}

func (s *Server) Stop() error {
    s.cluster.ShutDownClients()
    s.ttBarrier.Close()
    s.ttMsgStream.Close()
    s.k2sMsgStream.Close()
    s.ttBarrier.Close()
    s.msgProducer.Close()
    s.segmentInfoStream.Close()
    s.stopServerLoop()
@@ -475,24 +483,47 @@ func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeReque
    log.Debug("DataService: RegisterNode:", zap.String("IP", req.Address.Ip), zap.Int64("Port", req.Address.Port))
    node, err := s.newDataNode(req.Address.Ip, req.Address.Port, req.Base.SourceID)
    if err != nil {
        return nil, err
        ret.Status.Reason = err.Error()
        return ret, nil
    }

    s.cluster.Register(node)
    resp, err := node.client.WatchDmChannels(s.ctx, &datapb.WatchDmChannelsRequest{
        Base: &commonpb.MsgBase{
            MsgType: 0,
            MsgID: 0,
            Timestamp: 0,
            SourceID: Params.NodeID,
        },
        ChannelNames: s.insertChannels,
    })

    if s.ddChannelName == "" {
        resp, err := s.masterClient.GetDdChannel(ctx)
        if err = VerifyResponse(resp, err); err != nil {
    if err = VerifyResponse(resp, err); err != nil {
        ret.Status.Reason = err.Error()
        return ret, nil
    }

    if err := s.getDDChannel(); err != nil {
        ret.Status.Reason = err.Error()
        return ret, nil
    }

    if s.ttBarrier != nil {
        if err = s.ttBarrier.AddPeer(node.id); err != nil {
            ret.Status.Reason = err.Error()
            return ret, err
            return ret, nil
        }
        s.ddChannelName = resp.Value
    }

    if err = s.cluster.Register(node); err != nil {
        ret.Status.Reason = err.Error()
        return ret, nil
    }

    ret.Status.ErrorCode = commonpb.ErrorCode_Success
    ret.InitParams = &internalpb.InitParams{
        NodeID: Params.NodeID,
        StartParams: []*commonpb.KeyValuePair{
            {Key: "DDChannelName", Value: s.ddChannelName},
            {Key: "DDChannelName", Value: s.ddChannelMu.name},
            {Key: "SegmentStatisticsChannelName", Value: Params.StatisticsChannelName},
            {Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
            {Key: "CompleteFlushChannelName", Value: Params.SegmentInfoChannelName},

@@ -502,7 +533,7 @@ func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeReque
    }

func (s *Server) newDataNode(ip string, port int64, id UniqueID) (*dataNode, error) {
    client := grpcdatanodeclient.NewClient(fmt.Sprintf("%s:%d", ip, port))
    client := s.createDataNodeClient(fmt.Sprintf("%s:%d", ip, port))
    if err := client.Init(); err != nil {
        return nil, err
    }
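RegisterNode now reports almost every failure through ret.Status.Reason with a nil transport error, instead of returning (nil, err); the gRPC call itself stays successful and the caller inspects the status, which matches how DataNode.Init checks resp.Status.ErrorCode above. A reduced, self-contained sketch of that convention (invented types):

package main

import (
    "errors"
    "fmt"
)

type status struct {
    ErrorCode string
    Reason    string
}

type registerResponse struct{ Status status }

// registerNode does not fail the RPC for business errors; it encodes them in the
// response status so the remote caller can read a reason instead of a bare error.
func registerNode(register func() error) (*registerResponse, error) {
    ret := &registerResponse{Status: status{ErrorCode: "UnexpectedError"}}
    if err := register(); err != nil {
        ret.Status.Reason = err.Error()
        return ret, nil
    }
    ret.Status.ErrorCode = "Success"
    return ret, nil
}

func main() {
    resp, _ := registerNode(func() error { return errors.New("datanode already exist") })
    fmt.Printf("%+v\n", resp.Status)
    resp, _ = registerNode(func() error { return nil })
    fmt.Printf("%+v\n", resp.Status)
}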
internal/dataservice/server_test.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package dataservice

import (
    "context"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/zilliztech/milvus-distributed/internal/msgstream"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/datapb"
    "github.com/zilliztech/milvus-distributed/internal/types"
)

func TestRegisterNode(t *testing.T) {
    Params.Init()
    Params.DataNodeNum = 1
    var err error
    factory := msgstream.NewPmsFactory()
    m := map[string]interface{}{
        "pulsarAddress": Params.PulsarAddress,
        "receiveBufSize": 1024,
        "pulsarBufSize": 1024,
    }
    err = factory.SetParams(m)
    assert.Nil(t, err)
    svr, err := CreateServer(context.TODO(), factory)
    assert.Nil(t, err)
    ms := newMockMasterService()
    err = ms.Init()
    assert.Nil(t, err)
    err = ms.Start()
    assert.Nil(t, err)
    defer ms.Stop()
    svr.SetMasterClient(ms)
    svr.createDataNodeClient = func(addr string) types.DataNode {
        return newMockDataNodeClient(0)
    }
    assert.Nil(t, err)
    err = svr.Init()
    assert.Nil(t, err)
    err = svr.Start()
    assert.Nil(t, err)
    defer svr.Stop()
    t.Run("register node", func(t *testing.T) {
        resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{
            Base: &commonpb.MsgBase{
                MsgType: 0,
                MsgID: 0,
                Timestamp: 0,
                SourceID: 1000,
            },
            Address: &commonpb.Address{
                Ip: "localhost",
                Port: 1000,
            },
        })
        assert.Nil(t, err)
        assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
        assert.EqualValues(t, Params.DataNodeNum, svr.cluster.GetNumOfNodes())
        assert.EqualValues(t, []int64{1000}, svr.cluster.GetNodeIDs())
    })
}
@@ -18,8 +18,7 @@ import (
func TestDataNodeTTWatcher(t *testing.T) {
    ctx := context.Background()
    Params.Init()
    c := make(chan struct{})
    cluster := newDataNodeCluster(c)
    cluster := newDataNodeCluster()
    defer cluster.ShutDownClients()
    schema := newTestSchema()
    allocator := newMockAllocator()

@@ -5,6 +5,7 @@ import (
    "sync"

    "github.com/zilliztech/milvus-distributed/internal/logutil"
    "go.uber.org/zap"

    "github.com/zilliztech/milvus-distributed/internal/log"
    ms "github.com/zilliztech/milvus-distributed/internal/msgstream"

@@ -39,7 +40,8 @@ func (producer *MsgProducer) broadcastMsg() {
    }
    tt, err := producer.ttBarrier.GetTimeTick()
    if err != nil {
        log.Debug("broadcast get time tick error")
        log.Debug("broadcast get time tick error", zap.Error(err))
        return
    }
    baseMsg := ms.BaseMsg{
        BeginTimestamp: tt,
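The log line now attaches the underlying error as a structured field via zap.Error instead of dropping it, which keeps the message constant and makes the error searchable; the internal log package accepts zap fields, as the diff shows. A minimal self-contained example of that logging style using go.uber.org/zap directly:

package main

import (
    "errors"

    "go.uber.org/zap"
)

func main() {
    logger, _ := zap.NewDevelopment()
    defer logger.Sync()

    err := errors.New("timetick barrier closed")

    // The message stays constant; the error travels as a structured field.
    logger.Debug("broadcast get time tick error", zap.Error(err))
}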
@@ -3,6 +3,7 @@ package timesync
import (
    "context"
    "errors"
    "fmt"
    "math"
    "sync"
    "sync/atomic"

@@ -25,6 +26,7 @@ type (
        GetTimeTick() (Timestamp, error)
        Start()
        Close()
        AddPeer(peerID UniqueID) error
    }

    softTimeTickBarrier struct {

@@ -37,13 +39,13 @@ type (
    }

    hardTimeTickBarrier struct {
        peer2Tt map[UniqueID]Timestamp
        outTt chan Timestamp
        ttStream ms.MsgStream
        ctx context.Context
        wg sync.WaitGroup
        loopCtx context.Context
        loopCancel context.CancelFunc
        peer2Tt map[UniqueID]Timestamp
        peer2TtMu sync.Mutex
        outTt chan Timestamp
        ttStream ms.MsgStream
        ctx context.Context
        cancel context.CancelFunc
        wg sync.WaitGroup
    }
)

@@ -149,7 +151,6 @@ func (ttBarrier *hardTimeTickBarrier) GetTimeTick() (Timestamp, error) {
func (ttBarrier *hardTimeTickBarrier) Start() {
    // Last timestamp synchronized
    ttBarrier.wg.Add(1)
    ttBarrier.loopCtx, ttBarrier.loopCancel = context.WithCancel(ttBarrier.ctx)
    state := Timestamp(0)
    go func(ctx context.Context) {
        defer logutil.LogPanic()

@@ -162,8 +163,10 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
            default:
            }
            ttmsgs := ttBarrier.ttStream.Consume()
            if len(ttmsgs.Msgs) > 0 {

            if ttmsgs != nil && len(ttmsgs.Msgs) > 0 {
                log.Debug("receive tt msg")
                ttBarrier.peer2TtMu.Lock()
                for _, timetickmsg := range ttmsgs.Msgs {
                    // Suppose ttmsg.Timestamp from stream is always larger than the previous one,
                    // that `ttmsg.Timestamp > oldT`

@@ -181,20 +184,20 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
                }

                    ttBarrier.peer2Tt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp

                    newState := ttBarrier.minTimestamp()
                    if newState > state {
                        ttBarrier.outTt <- newState
                        state = newState
                    }
                }
                ttBarrier.peer2TtMu.Unlock()
            }
        }
    }(ttBarrier.loopCtx)
    }(ttBarrier.ctx)
}

func (ttBarrier *hardTimeTickBarrier) Close() {
    ttBarrier.loopCancel()
    ttBarrier.cancel()
    ttBarrier.wg.Wait()
}

@@ -208,24 +211,29 @@ func (ttBarrier *hardTimeTickBarrier) minTimestamp() Timestamp {
    return tempMin
}

func NewHardTimeTickBarrier(ctx context.Context, ttStream ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier {
    if len(peerIds) <= 0 {
        log.Error("[newSoftTimeTickBarrier] peerIds is empty!")
        return nil
func (ttBarrier *hardTimeTickBarrier) AddPeer(peerID UniqueID) error {
    ttBarrier.peer2TtMu.Lock()
    defer ttBarrier.peer2TtMu.Unlock()
    if _, ok := ttBarrier.peer2Tt[peerID]; ok {
        return fmt.Errorf("peer %d already exist", peerID)
    }

    sttbarrier := hardTimeTickBarrier{}
    sttbarrier.ttStream = ttStream
    sttbarrier.outTt = make(chan Timestamp, 1024)

    sttbarrier.peer2Tt = make(map[UniqueID]Timestamp)
    sttbarrier.ctx = ctx
    for _, id := range peerIds {
        sttbarrier.peer2Tt[id] = Timestamp(0)
    }
    if len(peerIds) != len(sttbarrier.peer2Tt) {
        log.Warn("[newSoftTimeTickBarrier] there are duplicate peerIds!", zap.Int64s("peerIDs", peerIds))
    }

    return &sttbarrier
    ttBarrier.peer2Tt[peerID] = Timestamp(0)
    return nil
}

func NewHardTimeTickBarrier(ctx context.Context, ttStream ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier {
    ttbarrier := hardTimeTickBarrier{}
    ttbarrier.ttStream = ttStream
    ttbarrier.outTt = make(chan Timestamp, 1024)

    ttbarrier.peer2Tt = make(map[UniqueID]Timestamp)
    ttbarrier.ctx, ttbarrier.cancel = context.WithCancel(ctx)
    for _, id := range peerIds {
        ttbarrier.peer2Tt[id] = Timestamp(0)
    }
    if len(peerIds) != len(ttbarrier.peer2Tt) {
        log.Warn("[newHardTimeTickBarrier] there are duplicate peerIds!", zap.Int64s("peerIDs", peerIds))
    }

    return &ttbarrier
}
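AddPeer is what lets DataService grow the barrier as DataNodes register instead of fixing the peer set at construction time, and the peer map is now guarded by peer2TtMu because the consume loop and RegisterNode touch it concurrently. A self-contained sketch of that add-a-peer-under-a-mutex shape (toy types, not the real barrier):

package main

import (
    "fmt"
    "sync"
)

type barrier struct {
    mu    sync.Mutex
    peers map[int64]uint64 // peer ID -> last reported timestamp
}

// AddPeer registers a new peer at timestamp 0; duplicates are rejected so a
// node cannot be counted twice when computing the minimum timestamp.
func (b *barrier) AddPeer(id int64) error {
    b.mu.Lock()
    defer b.mu.Unlock()
    if _, ok := b.peers[id]; ok {
        return fmt.Errorf("peer %d already exist", id)
    }
    b.peers[id] = 0
    return nil
}

func main() {
    b := &barrier{peers: make(map[int64]uint64)}
    fmt.Println(b.AddPeer(1000)) // <nil>
    fmt.Println(b.AddPeer(1000)) // peer 1000 already exist
}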