mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 03:18:29 +08:00
Enhance dml channel operations (#12143)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
ac175b6f00
commit
d4c297b1a8
@ -22,11 +22,16 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
)
|
||||
|
||||
type dmlMsgStream struct {
|
||||
ms msgstream.MsgStream
|
||||
mutex sync.RWMutex
|
||||
refcnt int64
|
||||
}
|
||||
|
||||
type dmlChannels struct {
|
||||
core *Core
|
||||
namePrefix string
|
||||
capacity int64
|
||||
refcnt sync.Map
|
||||
idx *atomic.Int64
|
||||
pool sync.Map
|
||||
}
|
||||
@ -36,60 +41,71 @@ func newDmlChannels(c *Core, chanNamePrefix string, chanNum int64) *dmlChannels
|
||||
core: c,
|
||||
namePrefix: chanNamePrefix,
|
||||
capacity: chanNum,
|
||||
refcnt: sync.Map{},
|
||||
idx: atomic.NewInt64(0),
|
||||
pool: sync.Map{},
|
||||
}
|
||||
|
||||
var i int64
|
||||
for i = 0; i < chanNum; i++ {
|
||||
name := fmt.Sprintf("%s_%d", d.namePrefix, i)
|
||||
for i := int64(0); i < chanNum; i++ {
|
||||
name := getDmlChannelName(d.namePrefix, i)
|
||||
ms, err := c.msFactory.NewMsgStream(c.ctx)
|
||||
if err != nil {
|
||||
log.Error("Failed to add msgstream", zap.String("name", name), zap.Error(err))
|
||||
panic("Failed to add msgstream")
|
||||
}
|
||||
d.pool.Store(name, &ms)
|
||||
d.pool.Store(name, &dmlMsgStream{
|
||||
ms: ms,
|
||||
mutex: sync.RWMutex{},
|
||||
refcnt: 0,
|
||||
})
|
||||
}
|
||||
log.Debug("init dml channels", zap.Int64("num", chanNum))
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *dmlChannels) GetDmlMsgStreamName() string {
|
||||
cnt := d.idx.Load()
|
||||
name := fmt.Sprintf("%s_%d", d.namePrefix, cnt)
|
||||
d.idx.Store((cnt + 1) % d.capacity)
|
||||
return name
|
||||
cnt := d.idx.Inc()
|
||||
return getDmlChannelName(d.namePrefix, (cnt-1)%d.capacity)
|
||||
}
|
||||
|
||||
// ListChannels lists all dml channel names
|
||||
func (d *dmlChannels) ListChannels() []string {
|
||||
// ListPhysicalChannels lists all dml channel names
|
||||
func (d *dmlChannels) ListPhysicalChannels() []string {
|
||||
var chanNames []string
|
||||
d.refcnt.Range(
|
||||
d.pool.Range(
|
||||
func(k, v interface{}) bool {
|
||||
chanNames = append(chanNames, k.(string))
|
||||
dms := v.(*dmlMsgStream)
|
||||
dms.mutex.RLock()
|
||||
if dms.refcnt > 0 {
|
||||
chanNames = append(chanNames, k.(string))
|
||||
}
|
||||
dms.mutex.RUnlock()
|
||||
return true
|
||||
})
|
||||
return chanNames
|
||||
}
|
||||
|
||||
// GetNumChannels get current dml channel count
|
||||
func (d *dmlChannels) GetNumChannels() int {
|
||||
return len(d.ListChannels())
|
||||
func (d *dmlChannels) GetPhysicalChannelNum() int {
|
||||
return len(d.ListPhysicalChannels())
|
||||
}
|
||||
|
||||
// Broadcast broadcasts msg pack into specified channel
|
||||
func (d *dmlChannels) Broadcast(chanNames []string, pack *msgstream.MsgPack) error {
|
||||
for _, chanName := range chanNames {
|
||||
// only in-use chanName exist in refcnt
|
||||
if _, ok := d.refcnt.Load(chanName); ok {
|
||||
v, _ := d.pool.Load(chanName)
|
||||
if err := (*(v.(*msgstream.MsgStream))).Broadcast(pack); err != nil {
|
||||
v, ok := d.pool.Load(chanName)
|
||||
if !ok {
|
||||
log.Error("invalid channel name", zap.String("chanName", chanName))
|
||||
panic("invalid channel name: " + chanName)
|
||||
}
|
||||
dms := v.(*dmlMsgStream)
|
||||
|
||||
dms.mutex.RLock()
|
||||
if dms.refcnt > 0 {
|
||||
if err := dms.ms.Broadcast(pack); err != nil {
|
||||
log.Error("Broadcast failed", zap.String("chanName", chanName))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("channel %s not exist", chanName)
|
||||
}
|
||||
dms.mutex.RUnlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -98,22 +114,28 @@ func (d *dmlChannels) Broadcast(chanNames []string, pack *msgstream.MsgPack) err
|
||||
func (d *dmlChannels) BroadcastMark(chanNames []string, pack *msgstream.MsgPack) (map[string][]byte, error) {
|
||||
result := make(map[string][]byte)
|
||||
for _, chanName := range chanNames {
|
||||
// only in-use chanName exist in refcnt
|
||||
if _, ok := d.refcnt.Load(chanName); ok {
|
||||
v, _ := d.pool.Load(chanName)
|
||||
ids, err := (*(v.(*msgstream.MsgStream))).BroadcastMark(pack)
|
||||
v, ok := d.pool.Load(chanName)
|
||||
if !ok {
|
||||
log.Error("invalid channel name", zap.String("chanName", chanName))
|
||||
panic("invalid channel name: " + chanName)
|
||||
}
|
||||
dms := v.(*dmlMsgStream)
|
||||
|
||||
dms.mutex.RLock()
|
||||
if dms.refcnt > 0 {
|
||||
ids, err := dms.ms.BroadcastMark(pack)
|
||||
if err != nil {
|
||||
log.Error("BroadcastMark failed", zap.String("chanName", chanName))
|
||||
return result, err
|
||||
}
|
||||
for chanName, idList := range ids {
|
||||
for cn, idList := range ids {
|
||||
// idList should have length 1, just flat by iteration
|
||||
for _, id := range idList {
|
||||
result[chanName] = id.Serialize()
|
||||
result[cn] = id.Serialize()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return result, fmt.Errorf("channel %s not exist", chanName)
|
||||
}
|
||||
dms.mutex.RUnlock()
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@ -121,38 +143,43 @@ func (d *dmlChannels) BroadcastMark(chanNames []string, pack *msgstream.MsgPack)
|
||||
// AddProducerChannels add named channels as producer
|
||||
func (d *dmlChannels) AddProducerChannels(names ...string) {
|
||||
for _, name := range names {
|
||||
if v, ok := d.pool.Load(name); ok {
|
||||
var cnt int64
|
||||
if _, ok := d.refcnt.Load(name); !ok {
|
||||
ms := *(v.(*msgstream.MsgStream))
|
||||
ms.AsProducer([]string{name})
|
||||
cnt = 1
|
||||
} else {
|
||||
v, _ := d.refcnt.Load(name)
|
||||
cnt = v.(int64) + 1
|
||||
}
|
||||
d.refcnt.Store(name, cnt)
|
||||
log.Debug("assign dml channel", zap.String("chanName", name), zap.Int64("refcnt", cnt))
|
||||
} else {
|
||||
v, ok := d.pool.Load(name)
|
||||
if !ok {
|
||||
log.Error("invalid channel name", zap.String("chanName", name))
|
||||
panic("invalid channel name: " + name)
|
||||
}
|
||||
dms := v.(*dmlMsgStream)
|
||||
|
||||
dms.mutex.Lock()
|
||||
if dms.refcnt == 0 {
|
||||
dms.ms.AsProducer([]string{name})
|
||||
}
|
||||
dms.refcnt++
|
||||
dms.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveProducerChannels removes specified channels
|
||||
func (d *dmlChannels) RemoveProducerChannels(names ...string) {
|
||||
for _, name := range names {
|
||||
if v, ok := d.refcnt.Load(name); ok {
|
||||
cnt := v.(int64)
|
||||
if cnt > 1 {
|
||||
d.refcnt.Store(name, cnt-1)
|
||||
} else {
|
||||
v1, _ := d.pool.Load(name)
|
||||
ms := *(v1.(*msgstream.MsgStream))
|
||||
ms.Close()
|
||||
d.refcnt.Delete(name)
|
||||
v, ok := d.pool.Load(name)
|
||||
if !ok {
|
||||
log.Error("invalid channel name", zap.String("chanName", name))
|
||||
panic("invalid channel name: " + name)
|
||||
}
|
||||
dms := v.(*dmlMsgStream)
|
||||
|
||||
dms.mutex.Lock()
|
||||
if dms.refcnt > 0 {
|
||||
dms.refcnt--
|
||||
if dms.refcnt == 0 {
|
||||
dms.ms.Close()
|
||||
}
|
||||
}
|
||||
dms.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func getDmlChannelName(prefix string, idx int64) string {
|
||||
return fmt.Sprintf("%s_%d", prefix, idx)
|
||||
}
|
||||
|
@ -13,7 +13,6 @@ package rootcoord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
@ -43,36 +42,35 @@ func TestDmlChannels(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
dml := newDmlChannels(core, dmlChanPrefix, totalDmlChannelNum)
|
||||
chanNames := dml.ListChannels()
|
||||
chanNames := dml.ListPhysicalChannels()
|
||||
assert.Equal(t, 0, len(chanNames))
|
||||
|
||||
randStr := funcutil.RandomString(8)
|
||||
assert.Panics(t, func() { dml.AddProducerChannels(randStr) })
|
||||
|
||||
err = dml.Broadcast([]string{randStr}, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.EqualError(t, err, fmt.Sprintf("channel %s not exist", randStr))
|
||||
assert.Panics(t, func() { dml.Broadcast([]string{randStr}, nil) })
|
||||
assert.Panics(t, func() { dml.BroadcastMark([]string{randStr}, nil) })
|
||||
assert.Panics(t, func() { dml.RemoveProducerChannels(randStr) })
|
||||
|
||||
// dml_xxx_0 => {chanName0, chanName2}
|
||||
// dml_xxx_1 => {chanName1}
|
||||
chanName0 := dml.GetDmlMsgStreamName()
|
||||
dml.AddProducerChannels(chanName0)
|
||||
assert.Equal(t, 1, dml.GetNumChannels())
|
||||
assert.Equal(t, 1, dml.GetPhysicalChannelNum())
|
||||
|
||||
chanName1 := dml.GetDmlMsgStreamName()
|
||||
dml.AddProducerChannels(chanName1)
|
||||
assert.Equal(t, 2, dml.GetNumChannels())
|
||||
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
|
||||
|
||||
chanName2 := dml.GetDmlMsgStreamName()
|
||||
dml.AddProducerChannels(chanName2)
|
||||
assert.Equal(t, 2, dml.GetNumChannels())
|
||||
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
|
||||
|
||||
dml.RemoveProducerChannels(chanName0)
|
||||
assert.Equal(t, 2, dml.GetNumChannels())
|
||||
assert.Equal(t, 2, dml.GetPhysicalChannelNum())
|
||||
|
||||
dml.RemoveProducerChannels(chanName1)
|
||||
assert.Equal(t, 1, dml.GetNumChannels())
|
||||
assert.Equal(t, 1, dml.GetPhysicalChannelNum())
|
||||
|
||||
dml.RemoveProducerChannels(chanName0)
|
||||
assert.Equal(t, 0, dml.GetNumChannels())
|
||||
assert.Equal(t, 0, dml.GetPhysicalChannelNum())
|
||||
}
|
||||
|
@ -482,7 +482,7 @@ func (c *Core) setMsgStreams() error {
|
||||
metrics.RootCoordDDChannelTimeTick.Set(float64(tsoutil.Mod24H(t)))
|
||||
|
||||
//c.dmlChannels.BroadcastAll(&msgPack)
|
||||
pc := c.dmlChannels.ListChannels()
|
||||
pc := c.dmlChannels.ListPhysicalChannels()
|
||||
pt := make([]uint64, len(pc))
|
||||
for i := 0; i < len(pt); i++ {
|
||||
pt[i] = t
|
||||
|
@ -680,7 +680,7 @@ func TestRootCoord(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
assert.Equal(t, shardsNum, int32(core.dmlChannels.GetNumChannels()))
|
||||
assert.Equal(t, shardsNum, int32(core.dmlChannels.GetPhysicalChannelNum()))
|
||||
|
||||
createMeta, err := core.MetaTable.GetCollectionByName(collName, 0)
|
||||
assert.Nil(t, err)
|
||||
|
@ -310,7 +310,7 @@ func (t *timetickSync) GetProxyNum() int {
|
||||
|
||||
// GetChanNum return the num of channel
|
||||
func (t *timetickSync) GetChanNum() int {
|
||||
return t.core.dmlChannels.GetNumChannels()
|
||||
return t.core.dmlChannels.GetPhysicalChannelNum()
|
||||
}
|
||||
|
||||
func minTimeTick(tt ...typeutil.Timestamp) typeutil.Timestamp {
|
||||
|
Loading…
Reference in New Issue
Block a user