mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 03:18:29 +08:00
Optimize dml_channels (#5783)
* update timetickSync::UpdateTimeTick Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * update dml_channels.go Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * fix unittest Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * remove ProduceAll and BroadcastAll Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
5c18138a6b
commit
587ccc0557
@ -20,117 +20,111 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type dmlStream struct {
|
||||
msgStream msgstream.MsgStream
|
||||
valid bool
|
||||
}
|
||||
|
||||
type dmlChannels struct {
|
||||
core *Core
|
||||
lock sync.RWMutex
|
||||
dml map[string]msgstream.MsgStream
|
||||
dml map[string]*dmlStream
|
||||
}
|
||||
|
||||
func newDMLChannels(c *Core) *dmlChannels {
|
||||
return &dmlChannels{
|
||||
core: c,
|
||||
lock: sync.RWMutex{},
|
||||
dml: make(map[string]msgstream.MsgStream),
|
||||
dml: make(map[string]*dmlStream),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dmlChannels) GetNumChannles() int {
|
||||
d.lock.RLock()
|
||||
defer d.lock.RUnlock()
|
||||
return len(d.dml)
|
||||
}
|
||||
|
||||
func (d *dmlChannels) ProduceAll(pack *msgstream.MsgPack) {
|
||||
d.lock.RLock()
|
||||
defer d.lock.RUnlock()
|
||||
|
||||
for n, ms := range d.dml {
|
||||
if err := ms.Produce(pack); err != nil {
|
||||
log.Debug("msgstream produce error", zap.String("name", n), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dmlChannels) BroadcastMany(channels []string, pack *msgstream.MsgPack) error {
|
||||
d.lock.RLock()
|
||||
defer d.lock.RUnlock()
|
||||
for _, ch := range channels {
|
||||
ms, ok := d.dml[ch]
|
||||
if !ok {
|
||||
return fmt.Errorf("channel %s not exist", ch)
|
||||
}
|
||||
if err := ms.Broadcast(pack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dmlChannels) BroadcastAll(pack *msgstream.MsgPack) {
|
||||
d.lock.RLock()
|
||||
defer d.lock.RUnlock()
|
||||
|
||||
for n, ms := range d.dml {
|
||||
if err := ms.Broadcast(pack); err != nil {
|
||||
log.Debug("msgstream broadcast error", zap.String("name", n), zap.Error(err))
|
||||
count := 0
|
||||
for _, ds := range d.dml {
|
||||
if ds.valid {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (d *dmlChannels) Produce(name string, pack *msgstream.MsgPack) error {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
var err error
|
||||
ms, ok := d.dml[name]
|
||||
ds, ok := d.dml[name]
|
||||
if !ok {
|
||||
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mstream failed, name = %s, error=%w", name, err)
|
||||
}
|
||||
ms.AsProducer([]string{name})
|
||||
d.dml[name] = ms
|
||||
return fmt.Errorf("channel %s not exist", name)
|
||||
}
|
||||
return ms.Produce(pack)
|
||||
|
||||
if err := ds.msgStream.Produce(pack); err != nil {
|
||||
return err
|
||||
}
|
||||
if !ds.valid {
|
||||
ds.msgStream.Close()
|
||||
delete(d.dml, name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dmlChannels) Broadcast(name string, pack *msgstream.MsgPack) error {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
if len(name) == 0 {
|
||||
return fmt.Errorf("channel name is empty")
|
||||
}
|
||||
var err error
|
||||
ms, ok := d.dml[name]
|
||||
ds, ok := d.dml[name]
|
||||
if !ok {
|
||||
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create msgtream failed, name = %s, error=%w", name, err)
|
||||
}
|
||||
ms.AsProducer([]string{name})
|
||||
d.dml[name] = ms
|
||||
return fmt.Errorf("channel %s not exist", name)
|
||||
}
|
||||
return ms.Broadcast(pack)
|
||||
if err := ds.msgStream.Broadcast(pack); err != nil {
|
||||
return err
|
||||
}
|
||||
if !ds.valid {
|
||||
ds.msgStream.Close()
|
||||
delete(d.dml, name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dmlChannels) BroadcastAll(channels []string, pack *msgstream.MsgPack) error {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
for _, ch := range channels {
|
||||
ds, ok := d.dml[ch]
|
||||
if !ok {
|
||||
return fmt.Errorf("channel %s not exist", ch)
|
||||
}
|
||||
if err := ds.msgStream.Broadcast(pack); err != nil {
|
||||
return err
|
||||
}
|
||||
if !ds.valid {
|
||||
ds.msgStream.Close()
|
||||
delete(d.dml, ch)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dmlChannels) AddProducerChannels(names ...string) {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
var err error
|
||||
for _, name := range names {
|
||||
log.Debug("add dml channel", zap.String("channel name", name))
|
||||
ms, ok := d.dml[name]
|
||||
_, ok := d.dml[name]
|
||||
if !ok {
|
||||
ms, err = d.core.msFactory.NewMsgStream(d.core.ctx)
|
||||
ms, err := d.core.msFactory.NewMsgStream(d.core.ctx)
|
||||
if err != nil {
|
||||
log.Debug("add msgstream failed", zap.String("name", name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
ms.AsProducer([]string{name})
|
||||
d.dml[name] = ms
|
||||
d.dml[name] = &dmlStream{
|
||||
msgStream: ms,
|
||||
valid: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -141,22 +135,8 @@ func (d *dmlChannels) RemoveProducerChannels(names ...string) {
|
||||
|
||||
for _, name := range names {
|
||||
log.Debug("delete dml channel", zap.String("channel name", name))
|
||||
if ms, ok := d.dml[name]; ok {
|
||||
ms.Close()
|
||||
delete(d.dml, name)
|
||||
if ds, ok := d.dml[name]; ok {
|
||||
ds.valid = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dmlChannels) HasChannel(names ...string) bool {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
for _, name := range names {
|
||||
if _, ok := d.dml[name]; !ok {
|
||||
log.Debug("unknown channel", zap.String("channel name", name))
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -602,7 +602,7 @@ func (c *Core) setMsgStreams() error {
|
||||
CreateCollectionRequest: *req,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
return c.dmlChannels.BroadcastMany(channelNames, &msgPack)
|
||||
return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
|
||||
}
|
||||
|
||||
c.SendDdDropCollectionReq = func(ctx context.Context, req *internalpb.DropCollectionRequest, channelNames []string) error {
|
||||
@ -618,7 +618,7 @@ func (c *Core) setMsgStreams() error {
|
||||
DropCollectionRequest: *req,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
return c.dmlChannels.BroadcastMany(channelNames, &msgPack)
|
||||
return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
|
||||
}
|
||||
|
||||
c.SendDdCreatePartitionReq = func(ctx context.Context, req *internalpb.CreatePartitionRequest, channelNames []string) error {
|
||||
@ -634,7 +634,7 @@ func (c *Core) setMsgStreams() error {
|
||||
CreatePartitionRequest: *req,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
return c.dmlChannels.BroadcastMany(channelNames, &msgPack)
|
||||
return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
|
||||
}
|
||||
|
||||
c.SendDdDropPartitionReq = func(ctx context.Context, req *internalpb.DropPartitionRequest, channelNames []string) error {
|
||||
@ -650,7 +650,7 @@ func (c *Core) setMsgStreams() error {
|
||||
DropPartitionRequest: *req,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
return c.dmlChannels.BroadcastMany(channelNames, &msgPack)
|
||||
return c.dmlChannels.BroadcastAll(channelNames, &msgPack)
|
||||
}
|
||||
|
||||
if Params.DataServiceSegmentChannel == "" {
|
||||
@ -1885,12 +1885,6 @@ func (c *Core) UpdateChannelTimeTick(ctx context.Context, in *internalpb.Channel
|
||||
status.Reason = fmt.Sprintf("UpdateChannelTimeTick receive invalid message %d", in.Base.GetMsgType())
|
||||
return status, nil
|
||||
}
|
||||
if !c.dmlChannels.HasChannel(in.ChannelNames...) {
|
||||
log.Debug("update time tick with unkonw channel", zap.Int("input channel size", len(in.ChannelNames)), zap.Strings("input channels", in.ChannelNames))
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
status.Reason = fmt.Sprintf("update time tick with unknown channel name, input channels = %v", in.ChannelNames)
|
||||
return status, nil
|
||||
}
|
||||
err := c.chanTimeTick.UpdateTimeTick(in)
|
||||
if err != nil {
|
||||
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
|
@ -45,14 +45,14 @@ func newTimeTickSync(core *Core) *timetickSync {
|
||||
// sendToChannel send all channels' timetick to sendChan
|
||||
// lock is needed by the invoker
|
||||
func (t *timetickSync) sendToChannel() {
|
||||
if len(t.proxyTimeTick) == 0 {
|
||||
return
|
||||
}
|
||||
for _, v := range t.proxyTimeTick {
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(t.proxyTimeTick) == 0 {
|
||||
return
|
||||
}
|
||||
// clear proxyTimeTick and send a clone
|
||||
ptt := make(map[typeutil.UniqueID]*internalpb.ChannelTimeTickMsg)
|
||||
for k, v := range t.proxyTimeTick {
|
||||
@ -77,9 +77,11 @@ func (t *timetickSync) UpdateTimeTick(in *internalpb.ChannelTimeTickMsg) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("Skip ChannelTimeTickMsg from un-recognized proxy node %d", in.Base.SourceID)
|
||||
}
|
||||
if prev != nil && prev.Timestamps[0] >= in.Timestamps[0] {
|
||||
log.Debug("timestamp go back", zap.Int64("source id", in.Base.SourceID), zap.Uint64("prev ts", prev.Timestamps[0]), zap.Uint64("curr ts", in.Timestamps[0]))
|
||||
return nil
|
||||
if in.Base.SourceID == t.core.session.ServerID {
|
||||
if prev != nil && prev.Timestamps[0] >= in.Timestamps[0] {
|
||||
log.Debug("timestamp go back", zap.Int64("source id", in.Base.SourceID), zap.Uint64("prev ts", prev.Timestamps[0]), zap.Uint64("curr ts", in.Timestamps[0]))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
t.proxyTimeTick[in.Base.SourceID] = in
|
||||
t.sendToChannel()
|
||||
|
Loading…
Reference in New Issue
Block a user