mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
Improve the close
method in the graph (#19100)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
3927ae9952
commit
ceea04c274
@ -89,13 +89,6 @@ func (fg *TimeTickedFlowGraph) Close() {
|
||||
fg.stopOnce.Do(func() {
|
||||
for _, v := range fg.nodeCtx {
|
||||
if v.node.IsInputNode() {
|
||||
// close inputNode first
|
||||
v.Close()
|
||||
}
|
||||
}
|
||||
for _, v := range fg.nodeCtx {
|
||||
if !v.node.IsInputNode() {
|
||||
// close other nodes
|
||||
v.Close()
|
||||
}
|
||||
}
|
||||
|
@ -28,8 +28,9 @@ import (
|
||||
// InputNode is the entry point of flowgragh
|
||||
type InputNode struct {
|
||||
BaseNode
|
||||
inStream msgstream.MsgStream
|
||||
name string
|
||||
inStream msgstream.MsgStream
|
||||
name string
|
||||
closeMsgChan chan struct{}
|
||||
}
|
||||
|
||||
// IsInputNode returns whether Node is InputNode
|
||||
@ -44,10 +45,15 @@ func (inNode *InputNode) Start() {
|
||||
|
||||
// Close implements node
|
||||
func (inNode *InputNode) Close() {
|
||||
inNode.inStream.Close()
|
||||
log.Debug("message stream closed",
|
||||
zap.String("node name", inNode.name),
|
||||
)
|
||||
select {
|
||||
case <-inNode.closeMsgChan:
|
||||
return
|
||||
default:
|
||||
close(inNode.closeMsgChan)
|
||||
log.Debug("message stream closed",
|
||||
zap.String("node name", inNode.name),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns node name
|
||||
@ -62,37 +68,44 @@ func (inNode *InputNode) InStream() msgstream.MsgStream {
|
||||
|
||||
// Operate consume a message pack from msgstream and return
|
||||
func (inNode *InputNode) Operate(in []Msg) []Msg {
|
||||
msgPack, ok := <-inNode.inStream.Chan()
|
||||
if !ok {
|
||||
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name()))
|
||||
return []Msg{}
|
||||
}
|
||||
select {
|
||||
case <-inNode.closeMsgChan:
|
||||
inNode.inStream.Close()
|
||||
return []Msg{&MsgStreamMsg{
|
||||
isCloseMsg: true,
|
||||
}}
|
||||
case msgPack, ok := <-inNode.inStream.Chan():
|
||||
if !ok {
|
||||
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name()))
|
||||
return []Msg{}
|
||||
}
|
||||
|
||||
// TODO: add status
|
||||
if msgPack == nil {
|
||||
return nil
|
||||
}
|
||||
var spans []opentracing.Span
|
||||
for _, msg := range msgPack.Msgs {
|
||||
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
sp.LogFields(oplog.String("input_node name", inNode.Name()))
|
||||
spans = append(spans, sp)
|
||||
msg.SetTraceCtx(ctx)
|
||||
}
|
||||
// TODO: add status
|
||||
if msgPack == nil {
|
||||
return nil
|
||||
}
|
||||
var spans []opentracing.Span
|
||||
for _, msg := range msgPack.Msgs {
|
||||
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
sp.LogFields(oplog.String("input_node name", inNode.Name()))
|
||||
spans = append(spans, sp)
|
||||
msg.SetTraceCtx(ctx)
|
||||
}
|
||||
|
||||
var msgStreamMsg Msg = &MsgStreamMsg{
|
||||
tsMessages: msgPack.Msgs,
|
||||
timestampMin: msgPack.BeginTs,
|
||||
timestampMax: msgPack.EndTs,
|
||||
startPositions: msgPack.StartPositions,
|
||||
endPositions: msgPack.EndPositions,
|
||||
}
|
||||
var msgStreamMsg Msg = &MsgStreamMsg{
|
||||
tsMessages: msgPack.Msgs,
|
||||
timestampMin: msgPack.BeginTs,
|
||||
timestampMax: msgPack.EndTs,
|
||||
startPositions: msgPack.StartPositions,
|
||||
endPositions: msgPack.EndPositions,
|
||||
}
|
||||
|
||||
for _, span := range spans {
|
||||
span.Finish()
|
||||
}
|
||||
for _, span := range spans {
|
||||
span.Finish()
|
||||
}
|
||||
|
||||
return []Msg{msgStreamMsg}
|
||||
return []Msg{msgStreamMsg}
|
||||
}
|
||||
}
|
||||
|
||||
// NewInputNode composes an InputNode with provided MsgStream, name and parameters
|
||||
@ -102,8 +115,9 @@ func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength
|
||||
baseNode.SetMaxParallelism(maxParallelism)
|
||||
|
||||
return &InputNode{
|
||||
BaseNode: baseNode,
|
||||
inStream: inStream,
|
||||
name: nodeName,
|
||||
BaseNode: baseNode,
|
||||
inStream: inStream,
|
||||
name: nodeName,
|
||||
closeMsgChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
@ -41,10 +41,7 @@ func TestInputNode(t *testing.T) {
|
||||
produceStream.Produce(&msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := &InputNode{
|
||||
inStream: msgStream,
|
||||
name: nodeName,
|
||||
}
|
||||
inputNode := NewInputNode(msgStream, nodeName, 100, 100)
|
||||
defer inputNode.Close()
|
||||
|
||||
isInputNode := inputNode.IsInputNode()
|
||||
|
@ -32,6 +32,7 @@ type MsgStreamMsg struct {
|
||||
timestampMax Timestamp
|
||||
startPositions []*MsgPosition
|
||||
endPositions []*MsgPosition
|
||||
isCloseMsg bool
|
||||
}
|
||||
|
||||
// GenerateMsgStreamMsg is used to create a new MsgStreamMsg object
|
||||
|
@ -58,24 +58,29 @@ type nodeCtx struct {
|
||||
downstream []*nodeCtx
|
||||
downstreamInputChanIdx map[string]int
|
||||
|
||||
closeCh chan struct{} // notify work to exit
|
||||
closeWg sync.WaitGroup // block Close until work exit
|
||||
closeCh chan struct{} // notify work to exit
|
||||
}
|
||||
|
||||
// Start invoke Node `Start` method and start a worker goroutine
|
||||
func (nodeCtx *nodeCtx) Start() {
|
||||
nodeCtx.node.Start()
|
||||
|
||||
nodeCtx.closeWg.Add(1)
|
||||
go nodeCtx.work()
|
||||
}
|
||||
|
||||
func isCloseMsg(msgs []Msg) bool {
|
||||
if len(msgs) == 1 {
|
||||
msg, ok := msgs[0].(*MsgStreamMsg)
|
||||
return ok && msg.isCloseMsg
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// work handles node work spinning
|
||||
// 1. collectMessage from upstream or just produce Msg from InputNode
|
||||
// 2. invoke node.Operate
|
||||
// 3. deliver the Operate result to downstream nodes
|
||||
func (nodeCtx *nodeCtx) work() {
|
||||
defer nodeCtx.closeWg.Done()
|
||||
name := fmt.Sprintf("nodeCtxTtChecker-%s", nodeCtx.node.Name())
|
||||
var checker *timerecord.GroupChecker
|
||||
if enableTtChecker {
|
||||
@ -98,8 +103,19 @@ func (nodeCtx *nodeCtx) work() {
|
||||
nodeCtx.collectInputMessages()
|
||||
inputs = nodeCtx.inputMessages
|
||||
}
|
||||
n := nodeCtx.node
|
||||
res = n.Operate(inputs)
|
||||
// the input message decides whether the operate method is executed
|
||||
if isCloseMsg(inputs) {
|
||||
res = inputs
|
||||
}
|
||||
if len(res) == 0 {
|
||||
n := nodeCtx.node
|
||||
res = n.Operate(inputs)
|
||||
}
|
||||
// the res decide whether the node should be closed.
|
||||
if isCloseMsg(res) {
|
||||
close(nodeCtx.closeCh)
|
||||
nodeCtx.node.Close()
|
||||
}
|
||||
|
||||
if enableTtChecker {
|
||||
checker.Check(name)
|
||||
@ -127,13 +143,7 @@ func (nodeCtx *nodeCtx) work() {
|
||||
// Close handles cleanup logic and notify worker to quit
|
||||
func (nodeCtx *nodeCtx) Close() {
|
||||
if nodeCtx.node.IsInputNode() {
|
||||
nodeCtx.node.Close() // close input msgStream
|
||||
close(nodeCtx.closeCh)
|
||||
nodeCtx.closeWg.Wait()
|
||||
} else {
|
||||
close(nodeCtx.closeCh)
|
||||
nodeCtx.closeWg.Wait()
|
||||
nodeCtx.node.Close() // close output msgStream, and etc...
|
||||
nodeCtx.node.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,10 +156,7 @@ func (nodeCtx *nodeCtx) deliverMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int
|
||||
log.Warn(fmt.Sprintln(err))
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-nodeCtx.closeCh:
|
||||
case nodeCtx.inputChannels[inputChanIdx] <- msg:
|
||||
}
|
||||
nodeCtx.inputChannels[inputChanIdx] <- msg
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
@ -161,17 +168,13 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
// and move them to inputMessages.
|
||||
for i := 0; i < inputsNum; i++ {
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
select {
|
||||
case <-nodeCtx.closeCh:
|
||||
msg, ok := <-channel
|
||||
if !ok {
|
||||
// TODO: add status
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
case msg, ok := <-channel:
|
||||
if !ok {
|
||||
// TODO: add status
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
}
|
||||
|
||||
// timeTick alignment check
|
||||
@ -191,16 +194,12 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
for nodeCtx.inputMessages[i].TimeTick() != latestTime {
|
||||
log.Debug("Try to align timestamp", zap.Uint64("t1", latestTime), zap.Uint64("t2", nodeCtx.inputMessages[i].TimeTick()))
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
select {
|
||||
case <-nodeCtx.closeCh:
|
||||
msg, ok := <-channel
|
||||
if !ok {
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
case msg, ok := <-channel:
|
||||
if !ok {
|
||||
log.Warn("input channel closed")
|
||||
return
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
}
|
||||
}
|
||||
sign <- struct{}{}
|
||||
@ -210,7 +209,6 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
case <-time.After(10 * time.Second):
|
||||
panic("Fatal, misaligned time tick, please restart pulsar")
|
||||
case <-sign:
|
||||
case <-nodeCtx.closeCh:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -73,10 +73,7 @@ func TestNodeCtx_Start(t *testing.T) {
|
||||
produceStream.Produce(&msgPack)
|
||||
|
||||
nodeName := "input_node"
|
||||
inputNode := &InputNode{
|
||||
inStream: msgStream,
|
||||
name: nodeName,
|
||||
}
|
||||
inputNode := NewInputNode(msgStream, nodeName, 100, 100)
|
||||
|
||||
node := &nodeCtx{
|
||||
node: inputNode,
|
||||
|
Loading…
Reference in New Issue
Block a user