Refactor datanode register

DataService should not be stalled while registering a data node.

Signed-off-by: sunby <bingyi.sun@zilliz.com>
sunby 2021-04-13 09:47:02 +08:00 committed by yefu.chen
parent deba964590
commit a7dac818ee
9 changed files with 332 additions and 109 deletions
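
For orientation: the heart of the change is that the data node cluster no longer blocks on a finish channel until Params.DataNodeNum nodes have registered; each RegisterNode call now sets up a single node (watch channels, DD channel, time-tick barrier) and returns on its own. A minimal, self-contained sketch of that shape, with hypothetical names rather than the actual Milvus types:

	// Illustrative sketch only; simplified stand-ins for the real
	// dataNodeCluster / RegisterNode path touched by this commit.
	package main

	import (
		"errors"
		"fmt"
		"sync"
	)

	type node struct {
		id   int64
		addr string
	}

	type cluster struct {
		mu    sync.Mutex
		nodes map[string]*node
	}

	// register adds a single node and reports duplicates immediately,
	// instead of waiting until a fixed node count has been reached.
	func (c *cluster) register(n *node) error {
		c.mu.Lock()
		defer c.mu.Unlock()
		if _, ok := c.nodes[n.addr]; ok {
			return errors.New("datanode already exists")
		}
		c.nodes[n.addr] = n
		return nil
	}

	func main() {
		c := &cluster{nodes: make(map[string]*node)}
		// Each RegisterNode RPC would do this independently; startup no
		// longer stalls waiting for "all" data nodes to arrive.
		for i := int64(0); i < 3; i++ {
			n := &node{id: i, addr: fmt.Sprintf("127.0.0.1:%d", 21124+i)}
			if err := c.register(n); err != nil {
				fmt.Println("register failed:", err)
				continue
			}
			fmt.Println("registered", n.addr)
		}
	}

In the actual diff, the per-node path additionally calls WatchDmChannels on the new node, fetches the DD channel lazily, and adds the node to the time-tick barrier before returning.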

View File

@@ -55,7 +55,7 @@ func NewDataNode(ctx context.Context, factory msgstream.Factory) *DataNode {
 		ctx:             ctx2,
 		cancel:          cancel2,
 		Role:            typeutil.DataNodeRole,
-		watchDm:         make(chan struct{}),
+		watchDm:         make(chan struct{}, 1),
 		dataSyncService: nil,
 		metaService:     nil,
@@ -106,6 +106,9 @@ func (node *DataNode) Init() error {
 	if err != nil {
 		return fmt.Errorf("Register node failed: %v", err)
 	}
+	if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
+		return fmt.Errorf("Receive error when registering data node, msg: %s", resp.Status.Reason)
+	}
 	select {
 	case <-time.After(RPCConnectionTimeout):

View File

@@ -2,6 +2,7 @@ package dataservice
 import (
 	"context"
+	"errors"
 	"fmt"
 	"sync"
@@ -26,30 +27,27 @@ type dataNode struct {
 }

 type dataNodeCluster struct {
 	sync.RWMutex
-	finishCh chan struct{}
-	nodes    []*dataNode
+	nodes []*dataNode
 }

 func (node *dataNode) String() string {
 	return fmt.Sprintf("id: %d, address: %s:%d", node.id, node.address.ip, node.address.port)
 }

-func newDataNodeCluster(finishCh chan struct{}) *dataNodeCluster {
+func newDataNodeCluster() *dataNodeCluster {
 	return &dataNodeCluster{
-		finishCh: finishCh,
-		nodes:    make([]*dataNode, 0),
+		nodes: make([]*dataNode, 0),
 	}
 }

-func (c *dataNodeCluster) Register(dataNode *dataNode) {
+func (c *dataNodeCluster) Register(dataNode *dataNode) error {
 	c.Lock()
 	defer c.Unlock()
 	if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
 		c.nodes = append(c.nodes, dataNode)
-		if len(c.nodes) == Params.DataNodeNum {
-			close(c.finishCh)
-		}
+		return nil
 	}
+	return errors.New("datanode already exist")
 }

 func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
@@ -151,6 +149,5 @@ func (c *dataNodeCluster) ShutDownClients() {
 func (c *dataNodeCluster) Clear() {
 	c.Lock()
 	defer c.Unlock()
-	c.finishCh = make(chan struct{})
 	c.nodes = make([]*dataNode, 0)
 }

View File

@@ -11,8 +11,7 @@ import (
 func TestDataNodeClusterRegister(t *testing.T) {
 	Params.Init()
 	Params.DataNodeNum = 3
-	ch := make(chan struct{})
-	cluster := newDataNodeCluster(ch)
+	cluster := newDataNodeCluster()
 	ids := make([]int64, 0, Params.DataNodeNum)
 	for i := 0; i < Params.DataNodeNum; i++ {
 		c := newMockDataNodeClient(int64(i))
@@ -31,8 +30,6 @@ func TestDataNodeClusterRegister(t *testing.T) {
 		})
 		ids = append(ids, int64(i))
 	}
-	_, ok := <-ch
-	assert.False(t, ok)
 	assert.EqualValues(t, Params.DataNodeNum, cluster.GetNumOfNodes())
 	assert.EqualValues(t, ids, cluster.GetNodeIDs())
 	states, err := cluster.GetDataNodeStates(context.TODO())
@@ -64,7 +61,7 @@ func TestWatchChannels(t *testing.T) {
 		{1, []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7"}, []int{3, 2, 2}},
 	}
-	cluster := newDataNodeCluster(make(chan struct{}))
+	cluster := newDataNodeCluster()
 	for _, c := range cases {
 		for i := 0; i < Params.DataNodeNum; i++ {
 			c := newMockDataNodeClient(int64(i))

View File

@@ -11,6 +11,7 @@ import (
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
+	"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
 )
@@ -98,3 +99,126 @@ func (c *mockDataNodeClient) Stop() error {
 	c.state = internalpb.StateCode_Abnormal
 	return nil
 }
+
+type mockMasterService struct {
+}
+
+func newMockMasterService() *mockMasterService {
+	return &mockMasterService{}
+}
+
+func (m *mockMasterService) Init() error {
+	return nil
+}
+
+func (m *mockMasterService) Start() error {
+	return nil
+}
+
+func (m *mockMasterService) Stop() error {
+	return nil
+}
+
+func (m *mockMasterService) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
+	return &internalpb.ComponentStates{
+		State: &internalpb.ComponentInfo{
+			NodeID:    0,
+			Role:      "",
+			StateCode: internalpb.StateCode_Healthy,
+			ExtraInfo: []*commonpb.KeyValuePair{},
+		},
+		SubcomponentStates: []*internalpb.ComponentInfo{},
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+			Reason:    "",
+		},
+	}, nil
+}
+
+func (m *mockMasterService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+//DDL request
+func (m *mockMasterService) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) {
+	return &milvuspb.ShowCollectionsResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+			Reason:    "",
+		},
+		CollectionNames: []string{},
+	}, nil
+}
+
+func (m *mockMasterService) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+//index builder service
+func (m *mockMasterService) CreateIndex(ctx context.Context, req *milvuspb.CreateIndexRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) DescribeIndex(ctx context.Context, req *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) DropIndex(ctx context.Context, req *milvuspb.DropIndexRequest) (*commonpb.Status, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+//global timestamp allocator
+func (m *mockMasterService) AllocTimestamp(ctx context.Context, req *masterpb.AllocTimestampRequest) (*masterpb.AllocTimestampResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) AllocID(ctx context.Context, req *masterpb.AllocIDRequest) (*masterpb.AllocIDResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+//segment
+func (m *mockMasterService) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) {
+	panic("not implemented") // TODO: Implement
+}
+
+func (m *mockMasterService) GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
+	return &milvuspb.StringResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+			Reason:    "",
+		},
+		Value: "ddchannel",
+	}, nil
+}

View File

@@ -39,40 +39,44 @@ type (
 	Timestamp = typeutil.Timestamp
 )

 type Server struct {
 	ctx               context.Context
 	serverLoopCtx     context.Context
 	serverLoopCancel  context.CancelFunc
 	serverLoopWg      sync.WaitGroup
 	state             atomic.Value
 	client            *etcdkv.EtcdKV
 	meta              *meta
 	segAllocator      segmentAllocatorInterface
 	statsHandler      *statsHandler
 	ddHandler         *ddHandler
 	allocator         allocatorInterface
 	cluster           *dataNodeCluster
 	msgProducer       *timesync.MsgProducer
-	registerFinishCh  chan struct{}
 	masterClient      types.MasterService
 	ttMsgStream       msgstream.MsgStream
 	k2sMsgStream      msgstream.MsgStream
-	ddChannelName     string
+	ddChannelMu       struct {
+		sync.Mutex
+		name string
+	}
 	segmentInfoStream msgstream.MsgStream
 	insertChannels    []string
 	msFactory         msgstream.Factory
 	ttBarrier         timesync.TimeTickBarrier
+
+	createDataNodeClient func(addr string) types.DataNode
 }

 func CreateServer(ctx context.Context, factory msgstream.Factory) (*Server, error) {
 	rand.Seed(time.Now().UnixNano())
-	ch := make(chan struct{})
 	s := &Server{
-		ctx:              ctx,
-		registerFinishCh: ch,
-		cluster:          newDataNodeCluster(ch),
-		msFactory:        factory,
+		ctx:       ctx,
+		cluster:   newDataNodeCluster(),
+		msFactory: factory,
 	}
 	s.insertChannels = s.getInsertChannels()
+	s.createDataNodeClient = func(addr string) types.DataNode {
+		return grpcdatanodeclient.NewClient(addr)
+	}
 	s.UpdateStateCode(internalpb.StateCode_Abnormal)
 	return s, nil
 }
@@ -105,10 +109,12 @@ func (s *Server) Start() error {
 		return err
 	}
-	s.allocator = newAllocator(s.masterClient)
 	if err = s.initMeta(); err != nil {
 		return err
 	}
+	s.allocator = newAllocator(s.masterClient)
 	s.statsHandler = newStatsHandler(s.meta)
 	s.ddHandler = newDDHandler(s.meta, s.segAllocator, s.masterClient)
 	s.initSegmentInfoChannel()
@@ -116,8 +122,6 @@ func (s *Server) Start() error {
 	if err = s.loadMetaFromMaster(); err != nil {
 		return err
 	}
-	s.waitDataNodeRegister()
-	s.cluster.WatchInsertChannels(s.insertChannels)
 	if err = s.initMsgProducer(); err != nil {
 		return err
 	}
@@ -149,11 +153,7 @@ func (s *Server) initMeta() error {
 		}
 		return nil
 	}
-	err := retry.Retry(100000, time.Millisecond*200, connectEtcdFn)
-	if err != nil {
-		return err
-	}
-	return nil
+	return retry.Retry(100000, time.Millisecond*200, connectEtcdFn)
 }

 func (s *Server) initSegmentInfoChannel() {
@@ -163,6 +163,7 @@ func (s *Server) initSegmentInfoChannel() {
 	s.segmentInfoStream = segmentInfoStream
 	s.segmentInfoStream.Start()
 }
+
 func (s *Server) initMsgProducer() error {
 	var err error
 	if s.ttMsgStream, err = s.msFactory.NewMsgStream(s.ctx); err != nil {
@@ -195,12 +196,8 @@ func (s *Server) loadMetaFromMaster() error {
 	if err = s.checkMasterIsHealthy(); err != nil {
 		return err
 	}
-	if s.ddChannelName == "" {
-		channel, err := s.masterClient.GetDdChannel(ctx)
-		if err != nil {
-			return err
-		}
-		s.ddChannelName = channel.Value
+	if err = s.getDDChannel(); err != nil {
+		return err
 	}
 	collections, err := s.masterClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
 		Base: &commonpb.MsgBase{
@@ -258,6 +255,19 @@ func (s *Server) loadMetaFromMaster() error {
 	return nil
 }

+func (s *Server) getDDChannel() error {
+	s.ddChannelMu.Lock()
+	defer s.ddChannelMu.Unlock()
+	if len(s.ddChannelMu.name) == 0 {
+		resp, err := s.masterClient.GetDdChannel(s.ctx)
+		if err = VerifyResponse(resp, err); err != nil {
+			return err
+		}
+		s.ddChannelMu.name = resp.Value
+	}
+	return nil
+}
+
 func (s *Server) checkMasterIsHealthy() error {
 	ticker := time.NewTicker(300 * time.Millisecond)
 	ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second)
@@ -364,9 +374,13 @@ func (s *Server) startProxyServiceTimeTickLoop(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			log.Debug("Proxy service timetick loop shut down")
+			return
 		default:
 		}
 		msgPack := flushStream.Consume()
+		if msgPack == nil {
+			continue
+		}
 		for _, msg := range msgPack.Msgs {
 			if msg.Type() != commonpb.MsgType_TimeTick {
 				log.Warn("receive unknown msg from proxy service timetick", zap.Stringer("msgType", msg.Type()))
@@ -384,8 +398,8 @@ func (s *Server) startProxyServiceTimeTickLoop(ctx context.Context) {
 func (s *Server) startDDChannel(ctx context.Context) {
 	defer s.serverLoopWg.Done()
 	ddStream, _ := s.msFactory.NewMsgStream(ctx)
-	ddStream.AsConsumer([]string{s.ddChannelName}, Params.DataServiceSubscriptionName)
-	log.Debug("dataservice AsConsumer: " + s.ddChannelName + " : " + Params.DataServiceSubscriptionName)
+	ddStream.AsConsumer([]string{s.ddChannelMu.name}, Params.DataServiceSubscriptionName)
+	log.Debug("dataservice AsConsumer: " + s.ddChannelMu.name + " : " + Params.DataServiceSubscriptionName)
 	ddStream.Start()
 	defer ddStream.Close()
 	for {
@@ -405,17 +419,11 @@ func (s *Server) startDDChannel(ctx context.Context) {
 	}
 }

-func (s *Server) waitDataNodeRegister() {
-	log.Debug("waiting data node to register")
-	<-s.registerFinishCh
-	log.Debug("all data nodes register")
-}
-
 func (s *Server) Stop() error {
 	s.cluster.ShutDownClients()
-	s.ttBarrier.Close()
 	s.ttMsgStream.Close()
 	s.k2sMsgStream.Close()
+	s.ttBarrier.Close()
 	s.msgProducer.Close()
 	s.segmentInfoStream.Close()
 	s.stopServerLoop()
@@ -475,24 +483,47 @@ func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeReque
 	log.Debug("DataService: RegisterNode:", zap.String("IP", req.Address.Ip), zap.Int64("Port", req.Address.Port))
 	node, err := s.newDataNode(req.Address.Ip, req.Address.Port, req.Base.SourceID)
 	if err != nil {
-		return nil, err
+		ret.Status.Reason = err.Error()
+		return ret, nil
 	}
-	s.cluster.Register(node)
-	if s.ddChannelName == "" {
-		resp, err := s.masterClient.GetDdChannel(ctx)
-		if err = VerifyResponse(resp, err); err != nil {
-			ret.Status.Reason = err.Error()
-			return ret, err
-		}
-		s.ddChannelName = resp.Value
-	}
+	resp, err := node.client.WatchDmChannels(s.ctx, &datapb.WatchDmChannelsRequest{
+		Base: &commonpb.MsgBase{
+			MsgType:   0,
+			MsgID:     0,
+			Timestamp: 0,
+			SourceID:  Params.NodeID,
+		},
+		ChannelNames: s.insertChannels,
+	})
+	if err = VerifyResponse(resp, err); err != nil {
+		ret.Status.Reason = err.Error()
+		return ret, nil
+	}
+	if err := s.getDDChannel(); err != nil {
+		ret.Status.Reason = err.Error()
+		return ret, nil
+	}
+	if s.ttBarrier != nil {
+		if err = s.ttBarrier.AddPeer(node.id); err != nil {
+			ret.Status.Reason = err.Error()
+			return ret, nil
+		}
+	}
+	if err = s.cluster.Register(node); err != nil {
+		ret.Status.Reason = err.Error()
+		return ret, nil
+	}
 	ret.Status.ErrorCode = commonpb.ErrorCode_Success
 	ret.InitParams = &internalpb.InitParams{
 		NodeID: Params.NodeID,
 		StartParams: []*commonpb.KeyValuePair{
-			{Key: "DDChannelName", Value: s.ddChannelName},
+			{Key: "DDChannelName", Value: s.ddChannelMu.name},
 			{Key: "SegmentStatisticsChannelName", Value: Params.StatisticsChannelName},
 			{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
 			{Key: "CompleteFlushChannelName", Value: Params.SegmentInfoChannelName},
@@ -502,7 +533,7 @@ func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeReque
 }

 func (s *Server) newDataNode(ip string, port int64, id UniqueID) (*dataNode, error) {
-	client := grpcdatanodeclient.NewClient(fmt.Sprintf("%s:%d", ip, port))
+	client := s.createDataNodeClient(fmt.Sprintf("%s:%d", ip, port))
 	if err := client.Init(); err != nil {
 		return nil, err
 	}

View File

@@ -0,0 +1,62 @@
+package dataservice
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
+	"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
+	"github.com/zilliztech/milvus-distributed/internal/types"
+)
+
+func TestRegisterNode(t *testing.T) {
+	Params.Init()
+	Params.DataNodeNum = 1
+	var err error
+	factory := msgstream.NewPmsFactory()
+	m := map[string]interface{}{
+		"pulsarAddress":  Params.PulsarAddress,
+		"receiveBufSize": 1024,
+		"pulsarBufSize":  1024,
+	}
+	err = factory.SetParams(m)
+	assert.Nil(t, err)
+	svr, err := CreateServer(context.TODO(), factory)
+	assert.Nil(t, err)
+
+	ms := newMockMasterService()
+	err = ms.Init()
+	assert.Nil(t, err)
+	err = ms.Start()
+	assert.Nil(t, err)
+	defer ms.Stop()
+	svr.SetMasterClient(ms)
+	svr.createDataNodeClient = func(addr string) types.DataNode {
+		return newMockDataNodeClient(0)
+	}
+	assert.Nil(t, err)
+	err = svr.Init()
+	assert.Nil(t, err)
+	err = svr.Start()
+	assert.Nil(t, err)
+	defer svr.Stop()
+
+	t.Run("register node", func(t *testing.T) {
+		resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   0,
+				MsgID:     0,
+				Timestamp: 0,
+				SourceID:  1000,
+			},
+			Address: &commonpb.Address{
+				Ip:   "localhost",
+				Port: 1000,
+			},
+		})
+		assert.Nil(t, err)
+		assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
+		assert.EqualValues(t, Params.DataNodeNum, svr.cluster.GetNumOfNodes())
+		assert.EqualValues(t, []int64{1000}, svr.cluster.GetNodeIDs())
+	})
+}

View File

@@ -18,8 +18,7 @@ import (
 func TestDataNodeTTWatcher(t *testing.T) {
 	ctx := context.Background()
 	Params.Init()
-	c := make(chan struct{})
-	cluster := newDataNodeCluster(c)
+	cluster := newDataNodeCluster()
 	defer cluster.ShutDownClients()
 	schema := newTestSchema()
 	allocator := newMockAllocator()

View File

@@ -5,6 +5,7 @@ import (
 	"sync"

 	"github.com/zilliztech/milvus-distributed/internal/logutil"
+	"go.uber.org/zap"

 	"github.com/zilliztech/milvus-distributed/internal/log"
 	ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
@@ -39,7 +40,8 @@ func (producer *MsgProducer) broadcastMsg() {
 		}
 		tt, err := producer.ttBarrier.GetTimeTick()
 		if err != nil {
-			log.Debug("broadcast get time tick error")
+			log.Debug("broadcast get time tick error", zap.Error(err))
+			return
 		}
 		baseMsg := ms.BaseMsg{
 			BeginTimestamp: tt,

View File

@@ -3,6 +3,7 @@ package timesync
 import (
 	"context"
	"errors"
+	"fmt"
 	"math"
 	"sync"
 	"sync/atomic"
@@ -25,6 +26,7 @@ type (
 		GetTimeTick() (Timestamp, error)
 		Start()
 		Close()
+		AddPeer(peerID UniqueID) error
 	}

 	softTimeTickBarrier struct {
@@ -37,13 +39,13 @@ type (
 	}

 	hardTimeTickBarrier struct {
-		peer2Tt    map[UniqueID]Timestamp
-		outTt      chan Timestamp
-		ttStream   ms.MsgStream
-		ctx        context.Context
-		wg         sync.WaitGroup
-		loopCtx    context.Context
-		loopCancel context.CancelFunc
+		peer2Tt   map[UniqueID]Timestamp
+		peer2TtMu sync.Mutex
+		outTt     chan Timestamp
+		ttStream  ms.MsgStream
+		ctx       context.Context
+		cancel    context.CancelFunc
+		wg        sync.WaitGroup
 	}
 )
@@ -149,7 +151,6 @@ func (ttBarrier *hardTimeTickBarrier) GetTimeTick() (Timestamp, error) {
 func (ttBarrier *hardTimeTickBarrier) Start() {
 	// Last timestamp synchronized
 	ttBarrier.wg.Add(1)
-	ttBarrier.loopCtx, ttBarrier.loopCancel = context.WithCancel(ttBarrier.ctx)
 	state := Timestamp(0)
 	go func(ctx context.Context) {
 		defer logutil.LogPanic()
@@ -162,8 +163,10 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
 			default:
 			}
 			ttmsgs := ttBarrier.ttStream.Consume()
-			if len(ttmsgs.Msgs) > 0 {
+			if ttmsgs != nil && len(ttmsgs.Msgs) > 0 {
 				log.Debug("receive tt msg")
+				ttBarrier.peer2TtMu.Lock()
 				for _, timetickmsg := range ttmsgs.Msgs {
 					// Suppose ttmsg.Timestamp from stream is always larger than the previous one,
 					// that `ttmsg.Timestamp > oldT`
@@ -181,20 +184,20 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
 					}
 					ttBarrier.peer2Tt[ttmsg.Base.SourceID] = ttmsg.Base.Timestamp
 					newState := ttBarrier.minTimestamp()
 					if newState > state {
 						ttBarrier.outTt <- newState
 						state = newState
 					}
 				}
+				ttBarrier.peer2TtMu.Unlock()
 			}
 		}
-	}(ttBarrier.loopCtx)
+	}(ttBarrier.ctx)
 }

 func (ttBarrier *hardTimeTickBarrier) Close() {
-	ttBarrier.loopCancel()
+	ttBarrier.cancel()
 	ttBarrier.wg.Wait()
 }
@@ -208,24 +211,29 @@ func (ttBarrier *hardTimeTickBarrier) minTimestamp() Timestamp {
 	return tempMin
 }

-func NewHardTimeTickBarrier(ctx context.Context, ttStream ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier {
-	if len(peerIds) <= 0 {
-		log.Error("[newSoftTimeTickBarrier] peerIds is empty!")
-		return nil
+func (ttBarrier *hardTimeTickBarrier) AddPeer(peerID UniqueID) error {
+	ttBarrier.peer2TtMu.Lock()
+	defer ttBarrier.peer2TtMu.Unlock()
+	if _, ok := ttBarrier.peer2Tt[peerID]; ok {
+		return fmt.Errorf("peer %d already exist", peerID)
 	}
+	ttBarrier.peer2Tt[peerID] = Timestamp(0)
+	return nil
+}

-	sttbarrier := hardTimeTickBarrier{}
-	sttbarrier.ttStream = ttStream
-	sttbarrier.outTt = make(chan Timestamp, 1024)
-
-	sttbarrier.peer2Tt = make(map[UniqueID]Timestamp)
-	sttbarrier.ctx = ctx
-	for _, id := range peerIds {
-		sttbarrier.peer2Tt[id] = Timestamp(0)
-	}
-	if len(peerIds) != len(sttbarrier.peer2Tt) {
-		log.Warn("[newSoftTimeTickBarrier] there are duplicate peerIds!", zap.Int64s("peerIDs", peerIds))
-	}
-
-	return &sttbarrier
+func NewHardTimeTickBarrier(ctx context.Context, ttStream ms.MsgStream, peerIds []UniqueID) *hardTimeTickBarrier {
+	ttbarrier := hardTimeTickBarrier{}
+	ttbarrier.ttStream = ttStream
+	ttbarrier.outTt = make(chan Timestamp, 1024)
+
+	ttbarrier.peer2Tt = make(map[UniqueID]Timestamp)
+	ttbarrier.ctx, ttbarrier.cancel = context.WithCancel(ctx)
+	for _, id := range peerIds {
+		ttbarrier.peer2Tt[id] = Timestamp(0)
+	}
+	if len(peerIds) != len(ttbarrier.peer2Tt) {
+		log.Warn("[newHardTimeTickBarrier] there are duplicate peerIds!", zap.Int64s("peerIDs", peerIds))
+	}
+
+	return &ttbarrier
 }