Add timeout in dispatcher, AsConsumer and Seek (#26686)

See also: #25309

Signed-off-by: yangxuan <xuan.yang@zilliz.com>

parent 0901b76732
commit 7f1ae35e72

Makefile (2)
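The change threads a context.Context through the message-stream and dispatcher entry points so that subscription, seek, and registration can be bounded by a caller-supplied timeout instead of blocking indefinitely on the MQ. The excerpt below is a reconstructed, abbreviated view of the ctx-aware surface, assembled from the hunks that follow; unchanged methods are elided.

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// Abbreviated, reconstructed view of the interfaces after this commit.
type MsgStream interface {
	AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error
	Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error
	Chan() <-chan *msgstream.MsgPack
	Close()
}

type DispatcherClient interface {
	Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)
	Deregister(vchannel string)
	Close()
}
```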
@@ -88,7 +88,7 @@ else
	@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh internal/
	@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/integration/
	@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/go/
	@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh pkg/
	@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh pkg/
endif

lint: tools/bin/revive
@@ -22,7 +22,6 @@ import (
	"time"

	"github.com/cockroachdb/errors"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/tsoutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
@ -623,7 +622,6 @@ func Test_compactionTrigger_force_maxSegmentLimit(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
vecFieldID := int64(201)
|
||||
segmentInfos := &SegmentsInfo{
|
||||
segments: make(map[UniqueID]*SegmentInfo),
|
||||
@ -830,7 +828,6 @@ func Test_compactionTrigger_noplan(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
Params.DataCoordCfg.MinSegmentToMerge.DefaultValue = "4"
|
||||
vecFieldID := int64(201)
|
||||
tests := []struct {
|
||||
@ -972,7 +969,6 @@ func Test_compactionTrigger_PrioritizedCandi(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
vecFieldID := int64(201)
|
||||
|
||||
genSeg := func(segID, numRows int64) *datapb.SegmentInfo {
|
||||
@ -1155,7 +1151,6 @@ func Test_compactionTrigger_SmallCandi(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
vecFieldID := int64(201)
|
||||
|
||||
genSeg := func(segID, numRows int64) *datapb.SegmentInfo {
|
||||
@ -1338,7 +1333,6 @@ func Test_compactionTrigger_SqueezeNonPlannedSegs(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
vecFieldID := int64(201)
|
||||
|
||||
genSeg := func(segID, numRows int64) *datapb.SegmentInfo {
|
||||
@ -1517,7 +1511,6 @@ func Test_compactionTrigger_noplan_random_size(t *testing.T) {
|
||||
collectionID int64
|
||||
compactTime *compactTime
|
||||
}
|
||||
paramtable.Init()
|
||||
|
||||
segmentInfos := &SegmentsInfo{
|
||||
segments: make(map[UniqueID]*SegmentInfo),
|
||||
@ -1689,8 +1682,6 @@ func Test_compactionTrigger_noplan_random_size(t *testing.T) {
|
||||
|
||||
// Test shouldDoSingleCompaction
|
||||
func Test_compactionTrigger_shouldDoSingleCompaction(t *testing.T) {
|
||||
paramtable.Init()
|
||||
|
||||
trigger := newCompactionTrigger(&meta{}, &compactionPlanHandler{}, newMockAllocator(), newMockHandler())
|
||||
|
||||
// Test too many deltalogs.
|
||||
|
@@ -604,7 +604,7 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) {
	}
	subName := fmt.Sprintf("%s-%d-datanodeTl", Params.CommonCfg.DataCoordSubName.GetValue(), paramtable.GetNodeID())

	ttMsgStream.AsConsumer([]string{timeTickChannel}, subName, mqwrapper.SubscriptionPositionLatest)
	ttMsgStream.AsConsumer(context.TODO(), []string{timeTickChannel}, subName, mqwrapper.SubscriptionPositionLatest)
	log.Info("DataCoord creates the timetick channel consumer",
		zap.String("timeTickChannel", timeTickChannel),
		zap.String("subscription", subName))
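DataCoord's time-tick loop only swaps in the new signature with context.TODO(), and the returned error is not checked at this call site. A minimal sketch of how the subscription could instead be bounded by the loop's own ctx (the helper name and the 10-second budget are assumptions, not part of the patch):

```go
package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// subscribeTtChannel is illustrative only: it derives a deadline from the
// caller's ctx and surfaces the AsConsumer error instead of ignoring it.
func subscribeTtChannel(ctx context.Context, tt msgstream.MsgStream, channel, subName string) error {
	subCtx, cancel := context.WithTimeout(ctx, 10*time.Second) // budget is an assumption
	defer cancel()
	return tt.AsConsumer(subCtx, []string{channel}, subName, mqwrapper.SubscriptionPositionLatest)
}
```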
@@ -54,6 +54,7 @@ import (
	"github.com/milvus-io/milvus/internal/util/sessionutil"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/etcd"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/merr"
@ -4378,3 +4379,225 @@ func TestDataCoord_EnableActiveStandby(t *testing.T) {
|
||||
svr := testDataCoordBase(t)
|
||||
defer closeTestServer(t, svr)
|
||||
}
|
||||
|
||||
func TestDataNodeTtChannel(t *testing.T) {
|
||||
paramtable.Get().Save(Params.DataNodeCfg.DataNodeTimeTickByRPC.Key, "false")
|
||||
defer paramtable.Get().Reset(Params.DataNodeCfg.DataNodeTimeTickByRPC.Key)
|
||||
genMsg := func(msgType commonpb.MsgType, ch string, t Timestamp) *msgstream.DataNodeTtMsg {
|
||||
return &msgstream.DataNodeTtMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
DataNodeTtMsg: msgpb.DataNodeTtMsg{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: msgType,
|
||||
MsgID: 0,
|
||||
Timestamp: t,
|
||||
SourceID: 0,
|
||||
},
|
||||
ChannelName: ch,
|
||||
Timestamp: t,
|
||||
},
|
||||
}
|
||||
}
|
||||
t.Run("Test segment flush after tt", func(t *testing.T) {
|
||||
ch := make(chan any, 1)
|
||||
svr := newTestServer(t, ch)
|
||||
defer closeTestServer(t, svr)
|
||||
|
||||
svr.meta.AddCollection(&collectionInfo{
|
||||
ID: 0,
|
||||
Schema: newTestSchema(),
|
||||
Partitions: []int64{0},
|
||||
})
|
||||
|
||||
ttMsgStream, err := svr.factory.NewMsgStream(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()})
|
||||
defer ttMsgStream.Close()
|
||||
info := &NodeInfo{
|
||||
Address: "localhost:7777",
|
||||
NodeID: 0,
|
||||
}
|
||||
err = svr.cluster.Register(info)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
|
||||
NodeID: 0,
|
||||
PeerRole: "",
|
||||
SegmentIDRequests: []*datapb.SegmentIDRequest{
|
||||
{
|
||||
CollectionID: 0,
|
||||
PartitionID: 0,
|
||||
ChannelName: "ch-1",
|
||||
Count: 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.EqualValues(t, 1, len(resp.SegIDAssignments))
|
||||
assign := resp.SegIDAssignments[0]
|
||||
|
||||
resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Flush,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: 0,
|
||||
},
|
||||
DbID: 0,
|
||||
CollectionID: 0,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp2.Status.ErrorCode)
|
||||
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime)
|
||||
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
|
||||
SegmentID: assign.GetSegID(),
|
||||
NumRows: 1,
|
||||
})
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err = ttMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
flushMsg := <-ch
|
||||
flushReq := flushMsg.(*datapb.FlushSegmentsRequest)
|
||||
assert.EqualValues(t, 1, len(flushReq.SegmentIDs))
|
||||
assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0])
|
||||
})
|
||||
|
||||
t.Run("flush segment with different channels", func(t *testing.T) {
|
||||
ch := make(chan any, 1)
|
||||
svr := newTestServer(t, ch)
|
||||
defer closeTestServer(t, svr)
|
||||
svr.meta.AddCollection(&collectionInfo{
|
||||
ID: 0,
|
||||
Schema: newTestSchema(),
|
||||
Partitions: []int64{0},
|
||||
})
|
||||
ttMsgStream, err := svr.factory.NewMsgStream(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()})
|
||||
defer ttMsgStream.Close()
|
||||
info := &NodeInfo{
|
||||
Address: "localhost:7777",
|
||||
NodeID: 0,
|
||||
}
|
||||
err = svr.cluster.Register(info)
|
||||
assert.NoError(t, err)
|
||||
resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
|
||||
NodeID: 0,
|
||||
PeerRole: "",
|
||||
SegmentIDRequests: []*datapb.SegmentIDRequest{
|
||||
{
|
||||
CollectionID: 0,
|
||||
PartitionID: 0,
|
||||
ChannelName: "ch-1",
|
||||
Count: 100,
|
||||
},
|
||||
{
|
||||
CollectionID: 0,
|
||||
PartitionID: 0,
|
||||
ChannelName: "ch-2",
|
||||
Count: 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.EqualValues(t, 2, len(resp.SegIDAssignments))
|
||||
var assign *datapb.SegmentIDAssignment
|
||||
for _, segment := range resp.SegIDAssignments {
|
||||
if segment.GetChannelName() == "ch-1" {
|
||||
assign = segment
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, assign)
|
||||
resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Flush,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: 0,
|
||||
},
|
||||
DbID: 0,
|
||||
CollectionID: 0,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp2.Status.ErrorCode)
|
||||
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime)
|
||||
msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{
|
||||
SegmentID: assign.GetSegID(),
|
||||
NumRows: 1,
|
||||
})
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err = ttMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
flushMsg := <-ch
|
||||
flushReq := flushMsg.(*datapb.FlushSegmentsRequest)
|
||||
assert.EqualValues(t, 1, len(flushReq.SegmentIDs))
|
||||
assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0])
|
||||
})
|
||||
|
||||
t.Run("test expire allocation after receiving tt msg", func(t *testing.T) {
|
||||
ch := make(chan any, 1)
|
||||
helper := ServerHelper{
|
||||
eventAfterHandleDataNodeTt: func() { ch <- struct{}{} },
|
||||
}
|
||||
svr := newTestServer(t, nil, WithServerHelper(helper))
|
||||
defer closeTestServer(t, svr)
|
||||
|
||||
svr.meta.AddCollection(&collectionInfo{
|
||||
ID: 0,
|
||||
Schema: newTestSchema(),
|
||||
Partitions: []int64{0},
|
||||
})
|
||||
|
||||
ttMsgStream, err := svr.factory.NewMsgStream(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()})
|
||||
defer ttMsgStream.Close()
|
||||
node := &NodeInfo{
|
||||
NodeID: 0,
|
||||
Address: "localhost:7777",
|
||||
}
|
||||
err = svr.cluster.Register(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
|
||||
NodeID: 0,
|
||||
PeerRole: "",
|
||||
SegmentIDRequests: []*datapb.SegmentIDRequest{
|
||||
{
|
||||
CollectionID: 0,
|
||||
PartitionID: 0,
|
||||
ChannelName: "ch-1",
|
||||
Count: 100,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.EqualValues(t, 1, len(resp.SegIDAssignments))
|
||||
|
||||
assignedSegmentID := resp.SegIDAssignments[0].SegID
|
||||
segment := svr.meta.GetHealthySegment(assignedSegmentID)
|
||||
assert.EqualValues(t, 1, len(segment.allocations))
|
||||
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", resp.SegIDAssignments[0].ExpireTime)
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err = ttMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
<-ch
|
||||
segment = svr.meta.GetHealthySegment(assignedSegmentID)
|
||||
assert.EqualValues(t, 0, len(segment.allocations))
|
||||
})
|
||||
}
|
||||
|
@@ -454,7 +454,7 @@ func (dsService *dataSyncService) getChannelLatestMsgID(ctx context.Context, cha
		zap.String("pChannelName", pChannelName),
		zap.String("subscription", subName),
	)
	dmlStream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
	dmlStream.AsConsumer(ctx, []string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
	id, err := dmlStream.GetLatestMsgID(pChannelName)
	if err != nil {
		log.Error("fail to GetLatestMsgID", zap.String("pChannelName", pChannelName), zap.Error(err))
@@ -28,6 +28,7 @@ import (
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/retry"
	"github.com/milvus-io/milvus/pkg/util/tsoutil"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -127,6 +128,10 @@ func (dn *deleteNode) Operate(in []Msg) []Msg {
			return dn.flushManager.flushDelData(buf, segmentToFlush, fgMsg.endPositions[0])
		}, getFlowGraphRetryOpt())
		if err != nil {
			if merr.IsCanceledOrTimeout(err) {
				log.Warn("skip syncing delete data for context done", zap.Int64("segmentID", segmentToFlush))
				continue
			}
			log.Fatal("failed to flush delete data", zap.Int64("segmentID", segmentToFlush), zap.Error(err))
		}
		// remove delete buf
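The new branch separates two failure modes: a flush aborted because the context was canceled or timed out is skipped and picked up on a later tick, while any other error remains fatal. A small sketch of that pattern in isolation, using the same retry and merr helpers (the wrapper itself is illustrative, not part of the patch):

```go
package example

import (
	"context"
	"log"

	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/retry"
)

// flushWithRetry retries a flush closure; a canceled or timed-out context is
// treated as "try again later" rather than as a fatal failure.
func flushWithRetry(ctx context.Context, flush func() error) (skipped bool, err error) {
	err = retry.Do(ctx, flush, retry.Attempts(3))
	if err != nil {
		if merr.IsCanceledOrTimeout(err) {
			log.Println("skip syncing data, context is done")
			return true, nil
		}
		return false, err // a real failure; the flow graph treats this as fatal
	}
	return false, nil
}
```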
@@ -17,6 +17,7 @@
package datanode

import (
	"context"
	"fmt"
	"time"

@@ -45,7 +46,7 @@ func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPos
	var err error
	var input <-chan *msgstream.MsgPack
	if seekPos != nil && len(seekPos.MsgID) != 0 {
		input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown)
		input, err = dispatcherClient.Register(context.TODO(), dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown)
		if err != nil {
			return nil, err
		}
@@ -54,7 +55,7 @@ func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPos
			zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
			zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
	} else {
		input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest)
		input, err = dispatcherClient.Register(context.TODO(), dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest)
		if err != nil {
			return nil, err
		}
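The input node keeps context.TODO() here; the interesting part is the branch between resuming from a stored seek position (subscribe at Unknown and let the dispatcher seek) and starting from the earliest message. A hedged sketch of the same decision with an explicit registration deadline — the helper name, timeout, and the msgdispatcher import path are assumptions based on the repo layout, not part of this patch:

```go
package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// registerInput resumes from a checkpoint when one exists, otherwise consumes
// from the earliest position; registration itself is bounded by a deadline.
func registerInput(client msgdispatcher.Client, vchannel string, seekPos *msgpb.MsgPosition) (<-chan *msgstream.MsgPack, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if seekPos != nil && len(seekPos.GetMsgID()) != 0 {
		// A stored position wins: subscribe "unknown" and let the dispatcher seek.
		return client.Register(ctx, vchannel, seekPos, mqwrapper.SubscriptionPositionUnknown)
	}
	return client.Register(ctx, vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
}
```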
@@ -71,7 +71,8 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {

func (mtm *mockTtMsgStream) AsProducer(channels []string) {}

func (mtm *mockTtMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
	return nil
}

func (mtm *mockTtMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
@@ -88,7 +89,7 @@ func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstrea
	return nil, nil
}

func (mtm *mockTtMsgStream) Seek(offset []*msgpb.MsgPosition) error {
func (mtm *mockTtMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error {
	return nil
}

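The no-op mock above simply adopts the new signatures. For tests that exercise the timeout behaviour it can be useful for a hand-written stub to actually honor the ctx; a minimal illustration (not part of the patch) showing only the two updated methods:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// fakeConsumer honors the new ctx parameter: an expired or canceled context is
// surfaced as an error, which lets timeout tests fail deterministically.
type fakeConsumer struct{}

func (f *fakeConsumer) AsConsumer(ctx context.Context, channels []string, subName string, pos mqwrapper.SubscriptionInitialPosition) error {
	return ctx.Err() // nil while live, context.Canceled/DeadlineExceeded otherwise
}

func (f *fakeConsumer) Seek(ctx context.Context, offsets []*msgpb.MsgPosition) error {
	return ctx.Err()
}
```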
@@ -38,6 +38,7 @@ import (
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/retry"
	"github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -495,6 +496,14 @@ func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID,
			metrics.DataNodeAutoFlushBufferCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc()
			metrics.DataNodeAutoFlushBufferCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.TotalLabel).Inc()
		}

		if merr.IsCanceledOrTimeout(err) {
			log.Warn("skip syncing buffer data for context done",
				zap.Int64("segmentID", task.segmentID),
				zap.Error(err),
			)
			continue
		}
		log.Fatal("insertBufferNode failed to flushBufferData",
			zap.Int64("segmentID", task.segmentID),
			zap.Error(err),
@ -30,8 +30,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
management "github.com/milvus-io/milvus/internal/http"
|
||||
@ -62,6 +60,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/logutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
|
@@ -17,6 +17,7 @@
package rmq

import (
	"context"
	"strconv"

	"github.com/cockroachdb/errors"
@@ -38,7 +39,7 @@ type rmqClient struct {
	client client.Client
}

func NewClientWithDefaultOptions() (mqwrapper.Client, error) {
func NewClientWithDefaultOptions(ctx context.Context) (mqwrapper.Client, error) {
	option := client.Options{Server: server.Rmq}
	return NewClient(option)
}
@ -66,8 +66,8 @@ func TestMqMsgStream_AsConsumer(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// repeat calling AsConsumer
|
||||
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
|
||||
m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
|
||||
m.AsConsumer(context.Background(), []string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
|
||||
m.AsConsumer(context.Background(), []string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown)
|
||||
}
|
||||
|
||||
func TestMqMsgStream_ComputeProduceChannelIndexes(t *testing.T) {
|
||||
@ -240,7 +240,7 @@ func TestMqMsgStream_SeekNotSubscribed(t *testing.T) {
|
||||
ChannelName: "b",
|
||||
},
|
||||
}
|
||||
err = m.Seek(p)
|
||||
err = m.Seek(context.Background(), p)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@ -265,7 +265,7 @@ func initRmqStream(ctx context.Context,
|
||||
) (msgstream.MsgStream, msgstream.MsgStream) {
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := NewClientWithDefaultOptions()
|
||||
rmqClient, _ := NewClientWithDefaultOptions(ctx)
|
||||
inputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
for _, opt := range opts {
|
||||
@ -273,9 +273,9 @@ func initRmqStream(ctx context.Context,
|
||||
}
|
||||
var input msgstream.MsgStream = inputStream
|
||||
|
||||
rmqClient2, _ := NewClientWithDefaultOptions()
|
||||
rmqClient2, _ := NewClientWithDefaultOptions(ctx)
|
||||
outputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest)
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest)
|
||||
var output msgstream.MsgStream = outputStream
|
||||
|
||||
return input, output
|
||||
@ -289,7 +289,7 @@ func initRmqTtStream(ctx context.Context,
|
||||
) (msgstream.MsgStream, msgstream.MsgStream) {
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := NewClientWithDefaultOptions()
|
||||
rmqClient, _ := NewClientWithDefaultOptions(ctx)
|
||||
inputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
inputStream.AsProducer(producerChannels)
|
||||
for _, opt := range opts {
|
||||
@ -297,9 +297,9 @@ func initRmqTtStream(ctx context.Context,
|
||||
}
|
||||
var input msgstream.MsgStream = inputStream
|
||||
|
||||
rmqClient2, _ := NewClientWithDefaultOptions()
|
||||
rmqClient2, _ := NewClientWithDefaultOptions(ctx)
|
||||
outputStream, _ := msgstream.NewMqTtMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest)
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest)
|
||||
var output msgstream.MsgStream = outputStream
|
||||
|
||||
return input, output
|
||||
@ -399,11 +399,11 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
|
||||
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := NewClientWithDefaultOptions()
|
||||
rmqClient, _ := NewClientWithDefaultOptions(ctx)
|
||||
outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
consumerSubName = funcutil.RandomString(8)
|
||||
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
|
||||
outputStream.Seek(receivedMsg.StartPositions)
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
|
||||
outputStream.Seek(ctx, receivedMsg.StartPositions)
|
||||
seekMsg := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 1+2)
|
||||
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
|
||||
@ -501,12 +501,12 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
|
||||
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := NewClientWithDefaultOptions()
|
||||
rmqClient, _ := NewClientWithDefaultOptions(ctx)
|
||||
outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
consumerSubName = funcutil.RandomString(8)
|
||||
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
|
||||
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
|
||||
|
||||
outputStream.Seek(receivedMsg3.StartPositions)
|
||||
outputStream.Seek(ctx, receivedMsg3.StartPositions)
|
||||
seekMsg := consumer(ctx, outputStream)
|
||||
assert.Equal(t, len(seekMsg.Msgs), 3)
|
||||
result := []uint64{14, 12, 13}
|
||||
@ -549,9 +549,9 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
outputStream.Close()
|
||||
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
rmqClient2, _ := NewClientWithDefaultOptions()
|
||||
rmqClient2, _ := NewClientWithDefaultOptions(ctx)
|
||||
outputStream2, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
|
||||
outputStream2.AsConsumer(ctx, consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
|
||||
|
||||
id := common.Endian.Uint64(seekPosition.MsgID) + 10
|
||||
bs := make([]byte, 8)
|
||||
@ -565,7 +565,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err = outputStream2.Seek(p)
|
||||
err = outputStream2.Seek(ctx, p)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
@ -589,7 +589,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
|
||||
|
||||
factory := msgstream.ProtoUDFactory{}
|
||||
|
||||
rmqClient, _ := NewClientWithDefaultOptions()
|
||||
rmqClient, _ := NewClientWithDefaultOptions(context.Background())
|
||||
|
||||
otherInputStream, _ := msgstream.NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
|
||||
otherInputStream.AsProducer([]string{"root_timetick"})
|
||||
@ -602,9 +602,9 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) {
|
||||
inputStream.Produce(getTimeTickMsgPack(int64(i)))
|
||||
}
|
||||
|
||||
rmqClient2, _ := NewClientWithDefaultOptions()
|
||||
rmqClient2, _ := NewClientWithDefaultOptions(context.Background())
|
||||
outputStream, _ := msgstream.NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher())
|
||||
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)
|
||||
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)
|
||||
|
||||
inputStream.Produce(getTimeTickMsgPack(1000))
|
||||
pack := <-outputStream.Chan()
|
||||
|
@ -252,7 +252,8 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
|
||||
func (ms *simpleMockMsgStream) AsProducer(channels []string) {
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
|
||||
func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {
|
||||
@ -292,7 +293,7 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *simpleMockMsgStream) Seek(offset []*msgstream.MsgPosition) error {
|
||||
func (ms *simpleMockMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -493,9 +493,12 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position
	// Random the subname in case we trying to load same delta at the same time
	subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
	log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts))
	stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
	err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
	if err != nil {
		return nil, err
	}

	err = stream.Seek([]*msgpb.MsgPosition{position})
	err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position})
	if err != nil {
		return nil, err
	}
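With AsConsumer and Seek now returning errors, the delegator can fail fast instead of silently consuming from the wrong position. Once both calls succeed, the surrounding code drains the stream up to the checkpoint; a standalone sketch of such a drain loop, with the ctx as the only stop condition besides the checkpoint timestamp (helper name is illustrative, and per-message handling is omitted):

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/pkg/mq/msgstream"
)

// drainUntil reads packs until the stream passes the checkpoint timestamp,
// the stream closes, or the caller's ctx expires.
func drainUntil(ctx context.Context, stream msgstream.MsgStream, ts msgstream.Timestamp) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case pack, ok := <-stream.Chan():
			if !ok {
				return nil // stream closed
			}
			if pack.EndTs >= ts {
				return nil // caught up to the checkpoint
			}
		}
	}
}
```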
@ -468,6 +468,79 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
}, sealed[0].Segments)
|
||||
})
|
||||
|
||||
s.Run("load_segments_with_streaming_delete_failed", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
|
||||
pks := &storage.PkStatistics{
|
||||
PkFilter: bf,
|
||||
}
|
||||
pks.UpdatePKRange(&storage.Int64FieldData{
|
||||
Data: []int64{10, 20, 30},
|
||||
})
|
||||
bfs.AddHistoricalStats(pks)
|
||||
return bfs
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
s.delegator.ProcessDelete([]*DeleteData{
|
||||
{
|
||||
PartitionID: 500,
|
||||
PrimaryKeys: []storage.PrimaryKey{
|
||||
storage.NewInt64PrimaryKey(1),
|
||||
storage.NewInt64PrimaryKey(10),
|
||||
},
|
||||
Timestamps: []uint64{10, 10},
|
||||
RowCount: 2,
|
||||
},
|
||||
}, 10)
|
||||
|
||||
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Close()
|
||||
ch := make(chan *msgstream.MsgPack, 10)
|
||||
close(ch)
|
||||
|
||||
s.mq.EXPECT().Chan().Return(ch)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 300,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 2},
|
||||
DeltaPosition: &msgpb.MsgPosition{Timestamp: 2},
|
||||
},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("get_worker_fail", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
|
@ -79,7 +79,7 @@ func (suite *PipelineManagerTestSuite) TestBasic() {
|
||||
// mock collection manager
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
//build manager
|
||||
|
@ -112,7 +112,7 @@ func (suite *PipelineTestSuite) TestBasic() {
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
// mock delegator
|
||||
|
@ -280,8 +280,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
|
||||
|
||||
// mocks
|
||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
|
||||
suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return()
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
|
||||
suite.msgStream.EXPECT().Close()
|
||||
|
||||
@ -329,8 +329,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
|
||||
|
||||
// mocks
|
||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
|
||||
suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return()
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
|
||||
suite.msgStream.EXPECT().Close()
|
||||
|
||||
@ -382,9 +382,9 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
|
||||
|
||||
// init msgstream failed
|
||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
|
||||
suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return()
|
||||
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
|
||||
suite.msgStream.EXPECT().Close().Return()
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything).Return(errors.New("mock error"))
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
|
||||
status, err = suite.node.WatchDmChannels(ctx, req)
|
||||
suite.NoError(err)
|
||||
|
@@ -188,7 +188,7 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref

	if params.PreCreatedTopicEnabled.GetAsBool() {
		subName := fmt.Sprintf("pre-created-topic-check-%s", name)
		ms.AsConsumer([]string{name}, subName, mqwrapper.SubscriptionPositionUnknown)
		ms.AsConsumer(ctx, []string{name}, subName, mqwrapper.SubscriptionPositionUnknown)
		// check topic exist and check the existed topic whether empty or not
		// kafka and rmq will err if the topic does not yet exist, pulsar will not
		// if one of the topics is not empty, panic
@ -283,7 +283,8 @@ func (ms *FailMsgStream) Close() {}
|
||||
func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) AsProducer(channels []string) {}
|
||||
func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
|
||||
func (ms *FailMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
|
||||
func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
|
||||
return nil
|
||||
}
|
||||
func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
|
||||
func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
|
||||
@ -294,8 +295,8 @@ func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.M
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) Seek(offset []*msgstream.MsgPosition) error { return nil }
|
||||
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
|
||||
func (ms *FailMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil }
|
||||
|
||||
func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) {
|
||||
return nil, nil
|
||||
|
@ -32,7 +32,7 @@ func TestInputNode(t *testing.T) {
|
||||
|
||||
msgStream, _ := factory.NewMsgStream(context.TODO())
|
||||
channels := []string{"cc"}
|
||||
msgStream.AsConsumer(channels, "sub", mqwrapper.SubscriptionPositionEarliest)
|
||||
msgStream.AsConsumer(context.Background(), channels, "sub", mqwrapper.SubscriptionPositionEarliest)
|
||||
|
||||
msgPack := generateMsgPack()
|
||||
produceStream, _ := factory.NewMsgStream(context.TODO())
|
||||
|
@ -62,7 +62,7 @@ func TestNodeCtx_Start(t *testing.T) {
|
||||
|
||||
msgStream, _ := factory.NewMsgStream(context.TODO())
|
||||
channels := []string{"cc"}
|
||||
msgStream.AsConsumer(channels, "sub", mqwrapper.SubscriptionPositionEarliest)
|
||||
msgStream.AsConsumer(context.TODO(), channels, "sub", mqwrapper.SubscriptionPositionEarliest)
|
||||
|
||||
produceStream, _ := factory.NewMsgStream(context.TODO())
|
||||
produceStream.AsProducer(channels)
|
||||
|
@@ -17,6 +17,7 @@
package pipeline

import (
	"context"
	"sync"
	"time"

@@ -68,7 +69,7 @@ func (p *streamPipeline) ConsumeMsgStream(position *msgpb.MsgPosition) error {
	}

	start := time.Now()
	p.input, err = p.dispatcher.Register(p.vChannel, position, mqwrapper.SubscriptionPositionUnknown)
	p.input, err = p.dispatcher.Register(context.TODO(), p.vChannel, position, mqwrapper.SubscriptionPositionUnknown)
	if err != nil {
		log.Error("dispatcher register failed", zap.String("channel", position.ChannelName))
		return WrapErrRegDispather(err)
@ -45,7 +45,7 @@ func (suite *StreamPipelineSuite) SetupTest() {
|
||||
suite.inChannel = make(chan *msgstream.MsgPack, 1)
|
||||
suite.outChannel = make(chan msgstream.Timestamp)
|
||||
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
|
||||
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.inChannel, nil)
|
||||
suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.inChannel, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel)
|
||||
suite.length = 4
|
||||
|
@ -69,7 +69,7 @@ type MockLogger_Record_Call struct {
|
||||
}
|
||||
|
||||
// Record is a helper method to define mock.On call
|
||||
// - _a0 Evt
|
||||
// - _a0 Evt
|
||||
func (_e *MockLogger_Expecter) Record(_a0 interface{}) *MockLogger_Record_Call {
|
||||
return &MockLogger_Record_Call{Call: _e.mock.On("Record", _a0)}
|
||||
}
|
||||
@ -102,8 +102,8 @@ type MockLogger_RecordFunc_Call struct {
|
||||
}
|
||||
|
||||
// RecordFunc is a helper method to define mock.On call
|
||||
// - _a0 Level
|
||||
// - _a1 func() Evt
|
||||
// - _a0 Level
|
||||
// - _a1 func() Evt
|
||||
func (_e *MockLogger_Expecter) RecordFunc(_a0 interface{}, _a1 interface{}) *MockLogger_RecordFunc_Call {
|
||||
return &MockLogger_RecordFunc_Call{Call: _e.mock.On("RecordFunc", _a0, _a1)}
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ require (
|
||||
go.uber.org/zap v1.20.0
|
||||
golang.org/x/crypto v0.9.0
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17
|
||||
golang.org/x/net v0.10.0
|
||||
golang.org/x/sync v0.1.0
|
||||
google.golang.org/grpc v1.54.0
|
||||
google.golang.org/protobuf v1.30.0
|
||||
@ -153,7 +154,6 @@ require (
|
||||
go.opentelemetry.io/otel/metric v0.35.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
|
||||
go.uber.org/multierr v1.7.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/oauth2 v0.6.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/term v0.8.0 // indirect
|
||||
|
@@ -17,6 +17,7 @@
package msgdispatcher

import (
	"context"
	"sync"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
@@ -35,7 +36,7 @@ type (
)

type Client interface {
	Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
	Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
	Deregister(vchannel string)
	Close()
}
@@ -60,7 +61,7 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
	}
}

func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
	log := log.With(zap.String("role", c.role),
		zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
	pchannel := funcutil.ToPhysicalChannel(vchannel)
@@ -73,7 +74,7 @@ func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *Msg
		c.managers[pchannel] = manager
		go manager.Run()
	}
	ch, err := manager.Add(vchannel, pos, subPos)
	ch, err := manager.Add(ctx, vchannel, pos, subPos)
	if err != nil {
		if manager.Num() == 0 {
			manager.Close()
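Register on the dispatcher client now accepts the caller's ctx and forwards it to the per-pchannel manager. A sketch of a caller that bounds registration with a deadline; the wrapper and its 10-second budget are assumptions, the msgdispatcher import path is inferred from the repo layout, and the cleanup-on-failure behaviour referenced in the comment is the one visible in the hunk above:

```go
package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// registerWithTimeout bounds how long Register may spend creating the
// underlying consumer; on failure the client already tears down an empty manager.
func registerWithTimeout(client msgdispatcher.Client, vchannel string) (<-chan *msgstream.MsgPack, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	ch, err := client.Register(ctx, vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
	if err != nil {
		return nil, err
	}
	return ch, nil
}
```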
@@ -17,10 +17,12 @@
package msgdispatcher

import (
	"context"
	"fmt"
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/atomic"
@@ -32,10 +34,25 @@ import (
func TestClient(t *testing.T) {
	client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
	assert.NotNil(t, client)
	_, err := client.Register("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
	_, err := client.Register(context.Background(), "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
	assert.NoError(t, err)
	_, err = client.Register(context.Background(), "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
	assert.NoError(t, err)
	assert.NotPanics(t, func() {
		client.Deregister("mock_vchannel_0")
		client.Close()
	})

	t.Run("with timeout ctx", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond)
		defer cancel()
		<-time.After(2 * time.Millisecond)

		client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1)
		defer client.Close()
		assert.NotNil(t, client)
		_, err := client.Register(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
		assert.Error(t, err)
	})
}

@@ -49,7 +66,7 @@ func TestClient_Concurrency(t *testing.T) {
		vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
		wg.Add(1)
		go func() {
			_, err := client1.Register(vchannel, nil, mqwrapper.SubscriptionPositionUnknown)
			_, err := client1.Register(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionUnknown)
			assert.NoError(t, err)
			for j := 0; j < rand.Intn(2); j++ {
				client1.Deregister(vchannel)
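The new "with timeout ctx" case builds a context whose deadline has already expired before calling Register, which makes the failure deterministic. A tiny reusable helper for that trick (illustrative only):

```go
package example

import (
	"context"
	"time"
)

// expiredContext returns a context whose 1ms deadline has already passed, so
// any ctx-aware call receiving it must fail immediately.
func expiredContext() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	<-ctx.Done() // wait out the deadline before handing the ctx to the caller
	return ctx, cancel
}
```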
@@ -78,7 +78,8 @@ type Dispatcher struct {
	stream msgstream.MsgStream
}

func NewDispatcher(factory msgstream.Factory,
func NewDispatcher(ctx context.Context,
	factory msgstream.Factory,
	isMain bool,
	pchannel string,
	position *Pos,
@@ -90,14 +91,19 @@ func NewDispatcher(factory msgstream.Factory,
	log := log.With(zap.String("pchannel", pchannel),
		zap.String("subName", subName), zap.Bool("isMain", isMain))
	log.Info("creating dispatcher...")
	stream, err := factory.NewTtMsgStream(context.Background())
	stream, err := factory.NewTtMsgStream(ctx)
	if err != nil {
		return nil, err
	}
	if position != nil && len(position.MsgID) != 0 {
		position.ChannelName = funcutil.ToPhysicalChannel(position.ChannelName)
		stream.AsConsumer([]string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown)
		err = stream.Seek([]*Pos{position})
		err = stream.AsConsumer(ctx, []string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown)
		if err != nil {
			log.Error("asConsumer failed", zap.Error(err))
			return nil, err
		}

		err = stream.Seek(ctx, []*Pos{position})
		if err != nil {
			stream.Close()
			log.Error("seek failed", zap.Error(err))
@@ -107,7 +113,11 @@ func NewDispatcher(factory msgstream.Factory,
		log.Info("seek successfully", zap.Time("posTime", posTime),
			zap.Duration("tsLag", time.Since(posTime)))
	} else {
		stream.AsConsumer([]string{pchannel}, subName, subPos)
		err := stream.AsConsumer(ctx, []string{pchannel}, subName, subPos)
		if err != nil {
			log.Error("asConsumer failed", zap.Error(err))
			return nil, err
		}
		log.Info("asConsumer successfully")
	}

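NewDispatcher now receives the ctx, threads it through stream creation, AsConsumer and Seek, and closes the stream when the seek fails. The standalone sketch below mirrors that subscribe-or-seek branch outside the dispatcher, purely as an illustration of the control flow:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

// subscribeOrSeek: with a stored position, subscribe "unknown" and seek to it;
// otherwise start from the requested initial position. Every step is bounded
// by ctx and a failed seek closes the stream so nothing leaks.
func subscribeOrSeek(ctx context.Context, stream msgstream.MsgStream, pchannel, subName string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) error {
	if pos != nil && len(pos.GetMsgID()) != 0 {
		if err := stream.AsConsumer(ctx, []string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown); err != nil {
			return err
		}
		if err := stream.Seek(ctx, []*msgpb.MsgPosition{pos}); err != nil {
			stream.Close()
			return err
		}
		return nil
	}
	return stream.AsConsumer(ctx, []string{pchannel}, subName, subPos)
}
```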
@@ -21,15 +21,19 @@ import (
	"testing"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"golang.org/x/net/context"

	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
)

func TestDispatcher(t *testing.T) {
	ctx := context.Background()
	t.Run("test base", func(t *testing.T) {
		d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
		d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil,
			"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
		assert.NoError(t, err)
		assert.NotPanics(t, func() {
@@ -49,8 +53,23 @@ func TestDispatcher(t *testing.T) {
		assert.Equal(t, pos.Timestamp, curTs)
	})

	t.Run("test AsConsumer fail", func(t *testing.T) {
		ms := msgstream.NewMockMsgStream(t)
		ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error"))
		factory := &msgstream.MockMqFactory{
			NewMsgStreamFunc: func(ctx context.Context) (msgstream.MsgStream, error) {
				return ms, nil
			},
		}
		d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil,
			"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)

		assert.Error(t, err)
		assert.Nil(t, d)
	})

	t.Run("test target", func(t *testing.T) {
		d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
		d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil,
			"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
		assert.NoError(t, err)
		output := make(chan *msgstream.MsgPack, 1024)
@@ -113,7 +132,7 @@ func TestDispatcher(t *testing.T) {
}

func BenchmarkDispatcher_handle(b *testing.B) {
	d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil,
	d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil,
		"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
	assert.NoError(b, err)

@@ -39,7 +39,7 @@ var (
)

type DispatcherManager interface {
	Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
	Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error)
	Remove(vchannel string)
	Num() int
	Run()
@@ -85,14 +85,14 @@ func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) strin
	return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain)
}

func (c *dispatcherManager) Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) {
	log := log.With(zap.String("role", c.role),
		zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))

	c.mu.Lock()
	defer c.mu.Unlock()
	isMain := c.mainDispatcher == nil
	d, err := NewDispatcher(c.factory, isMain, c.pchannel, pos,
	d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos,
		c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets)
	if err != nil {
		return nil, err
@@ -236,7 +236,7 @@ func (c *dispatcherManager) split(t *target) {
	var newSolo *Dispatcher
	err := retry.Do(context.Background(), func() error {
		var err error
		newSolo, err = NewDispatcher(c.factory, false, c.pchannel, t.pos,
		newSolo, err = NewDispatcher(context.Background(), c.factory, false, c.pchannel, t.pos,
			c.constructSubName(t.vchannel, false), mqwrapper.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets)
		return err
	}, retry.Attempts(10))
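Add forwards the caller's ctx into NewDispatcher, but split() still rebuilds dispatchers on context.Background() inside retry.Do(..., retry.Attempts(10)), so re-subscription after lag is not bounded by any caller deadline. Since retry.Do is ctx-aware, one could cap that loop as sketched below; this is an option shown for illustration, not something the patch does:

```go
package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/pkg/util/retry"
)

// resubscribeWithBudget caps the total time the retry loop may spend rebuilding
// a dispatcher; rebuild stands in for the NewDispatcher call above.
func resubscribeWithBudget(rebuild func(ctx context.Context) error) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute) // budget is an assumption
	defer cancel()
	return retry.Do(ctx, func() error {
		return rebuild(ctx)
	}, retry.Attempts(10))
}
```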
@ -46,7 +46,7 @@ func TestManager(t *testing.T) {
|
||||
for j := 0; j < r; j++ {
|
||||
offset++
|
||||
t.Logf("dyh add, %s", fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset))
|
||||
_, err := c.Add(fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(context.Background(), fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, offset, c.Num())
|
||||
}
|
||||
@ -61,13 +61,14 @@ func TestManager(t *testing.T) {
|
||||
|
||||
t.Run("test merge and split", func(t *testing.T) {
|
||||
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
|
||||
ctx := context.Background()
|
||||
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
|
||||
@ -85,13 +86,14 @@ func TestManager(t *testing.T) {
|
||||
|
||||
t.Run("test run and close", func(t *testing.T) {
|
||||
prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
|
||||
ctx := context.Background()
|
||||
c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
|
||||
assert.NotNil(t, c)
|
||||
_, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
_, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
_, err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, c.Num())
|
||||
|
||||
@@ -105,6 +107,28 @@
			c.Close()
		})
	})

	t.Run("test add timeout", func(t *testing.T) {
		prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, time.Millisecond*2)
		defer cancel()
		time.Sleep(time.Millisecond * 2)
		c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
		go c.Run()
		assert.NotNil(t, c)
		_, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown)
		assert.Error(t, err)
		_, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown)
		assert.Error(t, err)
		_, err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown)
		assert.Error(t, err)
		assert.Equal(t, 0, c.Num())

		assert.NotPanics(t, func() {
			c.Close()
		})
	})
}

type vchannelHelper struct {
|
||||
@ -255,7 +279,7 @@ func (suite *SimulationSuite) TestDispatchToVchannels() {
|
||||
suite.vchannels = make(map[string]*vchannelHelper, vchannelNum)
|
||||
for i := 0; i < vchannelNum; i++ {
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
output, err := suite.manager.Add(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
}
|
||||
@ -289,7 +313,7 @@ func (suite *SimulationSuite) TestMerge() {
|
||||
|
||||
for i := 0; i < vchannelNum; i++ {
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
output, err := suite.manager.Add(vchannel, positions[rand.Intn(len(positions))],
|
||||
output, err := suite.manager.Add(context.Background(), vchannel, positions[rand.Intn(len(positions))],
|
||||
mqwrapper.SubscriptionPositionUnknown) // seek from random position
|
||||
assert.NoError(suite.T(), err)
|
||||
suite.vchannels[vchannel] = &vchannelHelper{output: output}
|
||||
@ -325,7 +349,7 @@ func (suite *SimulationSuite) TestSplit() {
|
||||
DefaultTargetChanSize = 10
|
||||
}
|
||||
vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i)
|
||||
_, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
_, err := suite.manager.Add(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionEarliest)
|
||||
assert.NoError(suite.T(), err)
|
||||
}
|
||||
|
||||
|
@ -3,10 +3,13 @@
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
context "context"
|
||||
|
||||
mqwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
|
||||
msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
)
|
||||
|
||||
@ -66,7 +69,7 @@ type MockClient_Deregister_Call struct {
|
||||
}
|
||||
|
||||
// Deregister is a helper method to define mock.On call
|
||||
// - vchannel string
|
||||
// - vchannel string
|
||||
func (_e *MockClient_Expecter) Deregister(vchannel interface{}) *MockClient_Deregister_Call {
|
||||
return &MockClient_Deregister_Call{Call: _e.mock.On("Deregister", vchannel)}
|
||||
}
|
||||
@ -88,25 +91,25 @@ func (_c *MockClient_Deregister_Call) RunAndReturn(run func(string)) *MockClient
|
||||
return _c
|
||||
}
|
||||
|
||||
// Register provides a mock function with given fields: vchannel, pos, subPos
|
||||
func (_m *MockClient) Register(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
|
||||
ret := _m.Called(vchannel, pos, subPos)
|
||||
// Register provides a mock function with given fields: ctx, vchannel, pos, subPos
|
||||
func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
|
||||
ret := _m.Called(ctx, vchannel, pos, subPos)
|
||||
|
||||
var r0 <-chan *msgstream.MsgPack
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok {
|
||||
return rf(vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok {
|
||||
return rf(ctx, vchannel, pos, subPos)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
|
||||
r0 = rf(vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
|
||||
r0 = rf(ctx, vchannel, pos, subPos)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan *msgstream.MsgPack)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); ok {
|
||||
r1 = rf(vchannel, pos, subPos)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); ok {
|
||||
r1 = rf(ctx, vchannel, pos, subPos)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@ -120,16 +123,17 @@ type MockClient_Register_Call struct {
|
||||
}
|
||||
|
||||
// Register is a helper method to define mock.On call
|
||||
// - vchannel string
|
||||
// - pos *msgpb.MsgPosition
|
||||
// - subPos mqwrapper.SubscriptionInitialPosition
|
||||
func (_e *MockClient_Expecter) Register(vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
|
||||
return &MockClient_Register_Call{Call: _e.mock.On("Register", vchannel, pos, subPos)}
|
||||
// - ctx context.Context
|
||||
// - vchannel string
|
||||
// - pos *msgpb.MsgPosition
|
||||
// - subPos mqwrapper.SubscriptionInitialPosition
|
||||
func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
|
||||
return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)}
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) Run(run func(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call {
|
||||
func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(*msgpb.MsgPosition), args[2].(mqwrapper.SubscriptionInitialPosition))
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition), args[3].(mqwrapper.SubscriptionInitialPosition))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@ -139,7 +143,7 @@ func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 er
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) RunAndReturn(run func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
|
||||
func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([
return nil, err
}
defer stream.Close()
stream.AsConsumer([]string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest)
stream.AsConsumer(context.TODO(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest)
positions := make([]*msgstream.MsgPosition, 0)
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@ -14,7 +14,7 @@ var _ Factory = &CommonFactory{}
// It contains a function field named newer, which is a function that creates
// an mqwrapper.Client when called.
type CommonFactory struct {
Newer func() (mqwrapper.Client, error) // client constructor
Newer func(context.Context) (mqwrapper.Client, error) // client constructor
DispatcherFactory ProtoUDFactory
ReceiveBufSize int64
MQBufSize int64
@ -23,7 +23,7 @@ type CommonFactory struct {
// NewMsgStream is used to generate a new Msgstream object
func (f *CommonFactory) NewMsgStream(ctx context.Context) (ms MsgStream, err error) {
defer wrapError(&err, "NewMsgStream")
cli, err := f.Newer()
cli, err := f.Newer(ctx)
if err != nil {
return nil, err
}
@ -33,7 +33,7 @@ func (f *CommonFactory) NewTtMsgStream(ctx context.Context) (ms MsgStream, err err
// NewTtMsgStream is used to generate a new TtMsgstream object
func (f *CommonFactory) NewTtMsgStream(ctx context.Context) (ms MsgStream, err error) {
defer wrapError(&err, "NewTtMsgStream")
cli, err := f.Newer()
cli, err := f.Newer(ctx)
if err != nil {
return nil, err
}
@ -50,7 +50,7 @@ func (f *CommonFactory) NewMsgStreamDisposer(ctx context.Context) func([]string,
if err != nil {
return err
}
msgs.AsConsumer(channels, subName, mqwrapper.SubscriptionPositionUnknown)
msgs.AsConsumer(ctx, channels, subName, mqwrapper.SubscriptionPositionUnknown)
msgs.Close()
return nil
}
@ -764,8 +764,8 @@ func consume(ctx context.Context, mq MsgStream) *MsgPack {
func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string, seekPositions []*msgpb.MsgPosition) MsgStream {
consumer, err := newer(ctx)
assert.NoError(t, err)
consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
err = consumer.Seek(seekPositions)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
err = consumer.Seek(context.Background(), seekPositions)
assert.NoError(t, err)
return consumer
}
@ -780,14 +780,14 @@ func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channe
func createConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream {
consumer, err := newer(ctx)
assert.NoError(t, err)
consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)
return consumer
}

func createLatestConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream {
consumer, err := newer(ctx)
assert.NoError(t, err)
consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionLatest)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionLatest)
return consumer
}

@ -801,7 +801,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe

consumer, err := newer[1](ctx)
assert.NoError(t, err)
consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)

return producer, consumer
}
@ -45,7 +45,7 @@ func TestNmq(t *testing.T) {
f1 := NewNatsmqFactory()
f2 := NewNatsmqFactory()

client, err := nmq.NewClientWithDefaultOptions()
client, err := nmq.NewClientWithDefaultOptions(context.Background())
if err != nil {
panic(err)
}
@ -14,3 +14,7 @@ func NewMockMqFactory() *MockMqFactory {
func (m MockMqFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
return m.NewMsgStreamFunc(ctx)
}

func (m MockMqFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) {
return m.NewMsgStreamFunc(ctx)
}
@ -3,9 +3,12 @@
package msgstream

import (
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
context "context"

mqwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
mock "github.com/stretchr/testify/mock"

msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
)

// MockMsgStream is an autogenerated mock type for the MsgStream type
@ -21,9 +24,18 @@ func (_m *MockMsgStream) EXPECT() *MockMsgStream_Expecter {
return &MockMsgStream_Expecter{mock: &_m.Mock}
}

// AsConsumer provides a mock function with given fields: channels, subName, position
func (_m *MockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
_m.Called(channels, subName, position)
// AsConsumer provides a mock function with given fields: ctx, channels, subName, position
func (_m *MockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
ret := _m.Called(ctx, channels, subName, position)

var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []string, string, mqwrapper.SubscriptionInitialPosition) error); ok {
r0 = rf(ctx, channels, subName, position)
} else {
r0 = ret.Error(0)
}

return r0
}

// MockMsgStream_AsConsumer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsConsumer'
@ -32,26 +44,27 @@ type MockMsgStream_AsConsumer_Call struct {
}

// AsConsumer is a helper method to define mock.On call
// - channels []string
// - subName string
// - position mqwrapper.SubscriptionInitialPosition
func (_e *MockMsgStream_Expecter) AsConsumer(channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call {
return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", channels, subName, position)}
// - ctx context.Context
// - channels []string
// - subName string
// - position mqwrapper.SubscriptionInitialPosition
func (_e *MockMsgStream_Expecter) AsConsumer(ctx interface{}, channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call {
return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", ctx, channels, subName, position)}
}

func (_c *MockMsgStream_AsConsumer_Call) Run(run func(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call {
func (_c *MockMsgStream_AsConsumer_Call) Run(run func(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].(string), args[2].(mqwrapper.SubscriptionInitialPosition))
run(args[0].(context.Context), args[1].([]string), args[2].(string), args[3].(mqwrapper.SubscriptionInitialPosition))
})
return _c
}

func (_c *MockMsgStream_AsConsumer_Call) Return() *MockMsgStream_AsConsumer_Call {
_c.Call.Return()
func (_c *MockMsgStream_AsConsumer_Call) Return(_a0 error) *MockMsgStream_AsConsumer_Call {
_c.Call.Return(_a0)
return _c
}

func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func([]string, string, mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call {
func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context, []string, string, mqwrapper.SubscriptionInitialPosition) error) *MockMsgStream_AsConsumer_Call {
_c.Call.Return(run)
return _c
}
@ -67,7 +80,7 @@ type MockMsgStream_AsProducer_Call struct {
}

// AsProducer is a helper method to define mock.On call
// - channels []string
// - channels []string
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
}
@ -121,7 +134,7 @@ type MockMsgStream_Broadcast_Call struct {
}

// Broadcast is a helper method to define mock.On call
// - _a0 *MsgPack
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
}
@ -206,7 +219,7 @@ type MockMsgStream_CheckTopicValid_Call struct {
}

// CheckTopicValid is a helper method to define mock.On call
// - channel string
// - channel string
func (_e *MockMsgStream_Expecter) CheckTopicValid(channel interface{}) *MockMsgStream_CheckTopicValid_Call {
return &MockMsgStream_CheckTopicValid_Call{Call: _e.mock.On("CheckTopicValid", channel)}
}
@ -292,7 +305,7 @@ type MockMsgStream_GetLatestMsgID_Call struct {
}

// GetLatestMsgID is a helper method to define mock.On call
// - channel string
// - channel string
func (_e *MockMsgStream_Expecter) GetLatestMsgID(channel interface{}) *MockMsgStream_GetLatestMsgID_Call {
return &MockMsgStream_GetLatestMsgID_Call{Call: _e.mock.On("GetLatestMsgID", channel)}
}
@ -377,7 +390,7 @@ type MockMsgStream_Produce_Call struct {
}

// Produce is a helper method to define mock.On call
// - _a0 *MsgPack
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
}
@ -399,13 +412,13 @@ func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *Mo
return _c
}

// Seek provides a mock function with given fields: offset
func (_m *MockMsgStream) Seek(offset []*msgpb.MsgPosition) error {
ret := _m.Called(offset)
// Seek provides a mock function with given fields: ctx, offset
func (_m *MockMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error {
ret := _m.Called(ctx, offset)

var r0 error
if rf, ok := ret.Get(0).(func([]*msgpb.MsgPosition) error); ok {
r0 = rf(offset)
if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition) error); ok {
r0 = rf(ctx, offset)
} else {
r0 = ret.Error(0)
}
@ -419,14 +432,15 @@ type MockMsgStream_Seek_Call struct {
}

// Seek is a helper method to define mock.On call
// - offset []*msgpb.MsgPosition
func (_e *MockMsgStream_Expecter) Seek(offset interface{}) *MockMsgStream_Seek_Call {
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", offset)}
// - ctx context.Context
// - offset []*msgpb.MsgPosition
func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, offset interface{}) *MockMsgStream_Seek_Call {
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, offset)}
}

func (_c *MockMsgStream_Seek_Call) Run(run func(offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call {
func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*msgpb.MsgPosition))
run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition))
})
return _c
}
@ -436,7 +450,7 @@ func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call {
return _c
}

func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func([]*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call {
func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call {
_c.Call.Return(run)
return _c
}
@ -452,7 +466,7 @@ type MockMsgStream_SetRepackFunc_Call struct {
}

// SetRepackFunc is a helper method to define mock.On call
// - repackFunc RepackFunc
// - repackFunc RepackFunc
func (_e *MockMsgStream_Expecter) SetRepackFunc(repackFunc interface{}) *MockMsgStream_SetRepackFunc_Call {
return &MockMsgStream_SetRepackFunc_Call{Call: _e.mock.On("SetRepackFunc", repackFunc)}
}
@ -68,6 +68,15 @@ func NewPmsFactory(serviceParam *paramtable.ServiceParam) *PmsFactory {

// NewMsgStream is used to generate a new Msgstream object
func (f *PmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
var timeout time.Duration = f.RequestTimeout

if deadline, ok := ctx.Deadline(); ok {
if deadline.Before(time.Now()) {
return nil, errors.New("context timeout when NewMsgStream")
}
timeout = time.Until(deadline)
}

auth, err := f.getAuthentication()
if err != nil {
return nil, err
@ -75,7 +84,7 @@ func (f *PmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
clientOpts := pulsar.ClientOptions{
URL: f.PulsarAddress,
Authentication: auth,
OperationTimeout: f.RequestTimeout,
OperationTimeout: timeout,
}

pulsarClient, err := pulsarmqwrapper.NewClient(f.PulsarTenant, f.PulsarNameSpace, clientOpts)
@ -87,13 +96,21 @@ func (f *PmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {

// NewTtMsgStream is used to generate a new TtMsgstream object
func (f *PmsFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) {
var timeout time.Duration = f.RequestTimeout
if deadline, ok := ctx.Deadline(); ok {
if deadline.Before(time.Now()) {
return nil, errors.New("context timeout when NewTtMsgStream")
}
timeout = time.Until(deadline)
}
auth, err := f.getAuthentication()
if err != nil {
return nil, err
}
clientOpts := pulsar.ClientOptions{
URL: f.PulsarAddress,
Authentication: auth,
URL: f.PulsarAddress,
Authentication: auth,
OperationTimeout: timeout,
}

pulsarClient, err := pulsarmqwrapper.NewClient(f.PulsarTenant, f.PulsarNameSpace, clientOpts)
@ -156,12 +173,18 @@ type KmsFactory struct {
}

func (f *KmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) {
kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfig(f.config)
kafkaClient, err := kafkawrapper.NewKafkaClientInstanceWithConfig(ctx, f.config)
if err != nil {
return nil, err
}
return NewMqMsgStream(ctx, f.ReceiveBufSize, f.MQBufSize, kafkaClient, f.dispatcherFactory.NewUnmarshalDispatcher())
}

func (f *KmsFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) {
kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfig(f.config)
kafkaClient, err := kafkawrapper.NewKafkaClientInstanceWithConfig(ctx, f.config)
if err != nil {
return nil, err
}
return NewMqTtMsgStream(ctx, f.ReceiveBufSize, f.MQBufSize, kafkaClient, f.dispatcherFactory.NewUnmarshalDispatcher())
}

@ -171,7 +194,7 @@ func (f *KmsFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, st
if err != nil {
return err
}
msgstream.AsConsumer(channels, subname, mqwrapper.SubscriptionPositionUnknown)
msgstream.AsConsumer(ctx, channels, subname, mqwrapper.SubscriptionPositionUnknown)
msgstream.Close()
return nil
}
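A minimal standalone sketch of the deadline-to-timeout pattern the PmsFactory hunks above introduce: derive a timeout from ctx.Deadline(), fail fast if the deadline has already passed, otherwise fall back to a configured default. The helper name timeoutFromContext and the 10-second fallback are illustrative assumptions, not Milvus APIs.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// timeoutFromContext converts a context deadline into an operation timeout.
// An already-expired deadline is reported as an error instead of dialing.
func timeoutFromContext(ctx context.Context, fallback time.Duration) (time.Duration, error) {
	if deadline, ok := ctx.Deadline(); ok {
		if deadline.Before(time.Now()) {
			return 0, errors.New("context already expired")
		}
		return time.Until(deadline), nil
	}
	return fallback, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	timeout, err := timeoutFromContext(ctx, 10*time.Second) // 10s default is assumed
	if err != nil {
		fmt.Println("abort:", err)
		return
	}
	fmt.Println("dial with OperationTimeout =", timeout)
}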
@ -19,6 +19,7 @@ package msgstream
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
@ -26,15 +27,51 @@ import (
func TestPmsFactory(t *testing.T) {
pmsFactory := NewPmsFactory(&Params.ServiceParam)

ctx := context.Background()
_, err := pmsFactory.NewMsgStream(ctx)
err := pmsFactory.NewMsgStreamDisposer(context.Background())([]string{"hello"}, "xx")
assert.NoError(t, err)

_, err = pmsFactory.NewTtMsgStream(ctx)
assert.NoError(t, err)
tests := []struct {
description string
withTimeout bool
ctxTimeouted bool
expectedError bool
}{
{"normal ctx", false, false, false},
{"timeout ctx not timeout", true, false, false},
{"timeout ctx timeout", true, true, true},
}

err = pmsFactory.NewMsgStreamDisposer(ctx)([]string{"hello"}, "xx")
assert.NoError(t, err)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
var cancel context.CancelFunc
ctx := context.Background()
if test.withTimeout {
ctx, cancel = context.WithTimeout(ctx, time.Millisecond)
defer cancel()
}

if test.ctxTimeouted {
time.Sleep(time.Millisecond)
}
stream, err := pmsFactory.NewMsgStream(ctx)
if test.expectedError {
assert.Error(t, err)
assert.Nil(t, stream)
} else {
assert.NoError(t, err)
assert.NotNil(t, stream)
}

ttStream, err := pmsFactory.NewTtMsgStream(ctx)
if test.expectedError {
assert.Error(t, err)
assert.Nil(t, ttStream)
} else {
assert.NoError(t, err)
assert.NotNil(t, ttStream)
}
})
}
}

func TestPmsFactoryWithAuth(t *testing.T) {
@ -69,13 +106,47 @@ func TestPmsFactoryWithAuth(t *testing.T) {
func TestKafkaFactory(t *testing.T) {
kmsFactory := NewKmsFactory(&Params.ServiceParam)

ctx := context.Background()
_, err := kmsFactory.NewMsgStream(ctx)
assert.NoError(t, err)
tests := []struct {
description string
withTimeout bool
ctxTimeouted bool
expectedError bool
}{
{"normal ctx", false, false, false},
{"timeout ctx not timeout", true, false, false},
{"timeout ctx timeout", true, true, true},
}

_, err = kmsFactory.NewTtMsgStream(ctx)
assert.NoError(t, err)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
var cancel context.CancelFunc
ctx := context.Background()
timeoutDur := time.Millisecond * 30
if test.withTimeout {
ctx, cancel = context.WithTimeout(ctx, timeoutDur)
defer cancel()
}

// err = kmsFactory.NewMsgStreamDisposer(ctx)([]string{"hello"}, "xx")
// assert.NoError(t, err)
if test.ctxTimeouted {
time.Sleep(timeoutDur)
}
stream, err := kmsFactory.NewMsgStream(ctx)
if test.expectedError {
assert.Error(t, err)
assert.Nil(t, stream)
} else {
assert.NoError(t, err)
assert.NotNil(t, stream)
}

ttStream, err := kmsFactory.NewTtMsgStream(ctx)
if test.expectedError {
assert.Error(t, err)
assert.Nil(t, ttStream)
} else {
assert.NoError(t, err)
assert.NotNil(t, ttStream)
}
})
}
}
@ -145,7 +145,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
defer outputStream2.Close()
assert.NoError(t, err)

err = outputStream2.Seek([]*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
assert.NoError(t, err)

cnt := 0
@ -408,7 +408,7 @@ func TestStream_KafkaTtMsgStream_DataNodeTimetickMsgstream(t *testing.T) {
factory := ProtoUDFactory{}
kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress)
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)

var wg sync.WaitGroup
wg.Add(1)
@ -462,7 +462,7 @@ func getKafkaOutputStream(ctx context.Context, kafkaAddress string, consumerChan
factory := ProtoUDFactory{}
kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress)
outputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, position)
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, position)
return outputStream
}

@ -470,7 +470,7 @@ func getKafkaTtOutputStream(ctx context.Context, kafkaAddress string, consumerCh
factory := ProtoUDFactory{}
kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress)
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
return outputStream
}

@ -482,7 +482,7 @@ func getKafkaTtOutputStreamAndSeek(ctx context.Context, kafkaAddress string, pos
for _, c := range positions {
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(positions)
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
return outputStream
}
@ -33,6 +33,7 @@ import (

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -146,7 +147,7 @@ func (ms *mqMsgStream) CheckTopicValid(channel string) error {

// AsConsumerWithPosition Create consumer to receive message from channels, with initial position
// if initial position is set to latest, last message in the channel is exclusive
func (ms *mqMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
for _, channel := range channels {
if _, ok := ms.consumers[channel]; ok {
continue
@ -171,14 +172,19 @@ func (ms *mqMsgStream) AsConsumer(channels []string, subName string, position mq
ms.consumerChannels = append(ms.consumerChannels, channel)
return nil
}
// TODO if know the former subscribe is invalid, should we use pulsarctl to accelerate recovery speed
err := retry.Do(context.TODO(), fn, retry.Attempts(50), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))

err := retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))
if err != nil {
errMsg := "Failed to create consumer " + channel + ", error = " + err.Error()
panic(errMsg)
errMsg := fmt.Sprintf("Failed to create consumer %s", channel)
if merr.IsCanceledOrTimeout(err) {
return errors.Wrapf(err, errMsg)
}

panic(fmt.Sprintf("%s, errors = %s", errMsg, err.Error()))
}
log.Info("Successfully create consumer", zap.String("channel", channel), zap.String("subname", subName))
}
return nil
}

func (ms *mqMsgStream) SetRepackFunc(repackFunc RepackFunc) {
@ -425,7 +431,7 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack {

// Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive
// User has to ensure mq_msgstream is not closed before seek, and the seek position is already written.
func (ms *mqMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error {
func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
for _, mp := range msgPositions {
consumer, ok := ms.consumers[mp.ChannelName]
if !ok {
@ -509,7 +515,7 @@ func (ms *MqTtMsgStream) addConsumer(consumer mqwrapper.Consumer, channel string
}

// AsConsumerWithPosition subscribes channels as consumer for a MsgStream and seeks to a certain position.
func (ms *MqTtMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error {
for _, channel := range channels {
if _, ok := ms.consumers[channel]; ok {
continue
@ -533,12 +539,19 @@ func (ms *MqTtMsgStream) AsConsumer(channels []string, subName string, position
ms.addConsumer(pc, channel)
return nil
}
err := retry.Do(context.TODO(), fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))

err := retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))
if err != nil {
errMsg := "Failed to create consumer " + channel + ", error = " + err.Error()
panic(errMsg)
errMsg := fmt.Sprintf("Failed to create consumer %s", channel)
if merr.IsCanceledOrTimeout(err) {
return errors.Wrapf(err, errMsg)
}

panic(fmt.Sprintf("%s, errors = %s", errMsg, err.Error()))
}
}

return nil
}

// Close will stop goroutine and free internal producers and consumers
@ -773,7 +786,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
}

// Seek to the specified position
func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error {
func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
var consumer mqwrapper.Consumer
var mp *MsgPosition
var err error
@ -815,7 +828,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error {
if len(mp.MsgID) == 0 {
return fmt.Errorf("when msgID's length equal to 0, please use AsConsumer interface")
}
err = retry.Do(context.TODO(), fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))
err = retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second))
if err != nil {
return fmt.Errorf("failed to seek, error %s", err.Error())
}
@ -828,6 +841,8 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error {
select {
case <-ms.ctx.Done():
return ms.ctx.Err()
case <-ctx.Done():
return ctx.Err()
case msg, ok := <-consumer.Chan():
if !ok {
return fmt.Errorf("consumer closed")
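A standalone sketch of the control flow the AsConsumer change above introduces: retry subscription under the caller's context and surface a cancellation or deadline as a returned error instead of a panic. retryDo, asConsumer and the error classification below are stand-ins, not the Milvus retry or merr APIs.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retryDo stands in for retry.Do: run fn until it succeeds, the attempts run
// out, or the caller's context is done.
func retryDo(ctx context.Context, fn func() error, attempts int, sleep time.Duration) error {
	var last error
	for i := 0; i < attempts; i++ {
		if last = fn(); last == nil {
			return nil
		}
		select {
		case <-time.After(sleep):
		case <-ctx.Done():
			return errors.Join(last, ctx.Err())
		}
	}
	return last
}

// asConsumer sketches the new behaviour: a context error propagates to the
// caller, any other subscription failure still ends in a hard failure.
func asConsumer(ctx context.Context, channel string, subscribe func() error) error {
	err := retryDo(ctx, subscribe, 20, 200*time.Millisecond)
	if err == nil {
		return nil
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("failed to create consumer %s: %w", channel, err)
	}
	panic(fmt.Sprintf("failed to create consumer %s: %v", channel, err))
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	err := asConsumer(ctx, "demo-channel", func() error { return errors.New("broker unavailable") })
	fmt.Println(err) // the deadline becomes a returned error, not a panic
}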
@ -250,7 +250,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {

pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
var output MsgStream = outputStream

err := (*inputStream).Produce(&msgPack)
@ -301,7 +301,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {

pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
var output MsgStream = outputStream

err := (*inputStream).Produce(&msgPack)
@ -333,7 +333,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {

pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
var output MsgStream = outputStream

err := (*inputStream).Produce(&msgPack)
@ -482,12 +482,12 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
lastMsgID, err := outputStream2.GetLatestMsgID(c)
defer outputStream2.Close()
assert.NoError(t, err)

err = outputStream2.Seek([]*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
assert.NoError(t, err)

cnt := 0
@ -521,8 +521,34 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
assert.Equal(t, 4, cnt)
}

func TestStream_MsgStream_AsConsumerCtxDone(t *testing.T) {
pulsarAddress := getPulsarAddress()

t.Run("MsgStream AsConsumer with timeout context", func(t *testing.T) {
c1 := funcutil.RandomString(8)
consumerChannels := []string{c1}
consumerSubName := funcutil.RandomString(8)

ctx := context.Background()
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())

ctx, cancel := context.WithTimeout(ctx, time.Millisecond)
defer cancel()
<-time.After(2 * time.Millisecond)
err := outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
assert.Error(t, err)

omsgstream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
err = omsgstream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
assert.Error(t, err)
})
}

func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
pulsarAddress := getPulsarAddress()

c1 := funcutil.RandomString(8)
producerChannels := []string{c1}
consumerChannels := []string{c1}
@ -889,8 +915,8 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream2.Seek([]*msgpb.MsgPosition{seekPosition})
outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})

for i := 6; i < 10; i++ {
result := consumer(ctx, outputStream2)
@ -930,7 +956,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)
outputStream2.AsConsumer(ctx, consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest)
defer outputStream2.Close()
messageID, _ := pulsar.DeserializeMessageID(seekPosition.MsgID)
// try to seek to not written position
@ -945,7 +971,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}

err = outputStream2.Seek(p)
err = outputStream2.Seek(ctx, p)
assert.NoError(t, err)

for i := 10; i < 20; i++ {
@ -979,7 +1005,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) {
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)
outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest)

msgPack.Msgs = nil
// produce another 10 tsMs
@ -1321,7 +1347,7 @@ func getPulsarOutputStream(ctx context.Context, pulsarAddress string, consumerCh
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
return outputStream
}

@ -1329,7 +1355,7 @@ func getPulsarTtOutputStream(ctx context.Context, pulsarAddress string, consumer
factory := ProtoUDFactory{}
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
return outputStream
}

@ -1341,8 +1367,8 @@ func getPulsarTtOutputStreamAndSeek(ctx context.Context, pulsarAddress string, p
for _, c := range positions {
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(positions)
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
return outputStream
}
@ -1,10 +1,13 @@
package kafka

import (
"context"
"fmt"
"strconv"
"sync"
"time"

"github.com/cockroachdb/errors"
"github.com/confluentinc/confluent-kafka-go/kafka"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -54,9 +57,18 @@ func NewKafkaClientInstanceWithConfigMap(config kafka.ConfigMap, extraConsumerCo
return &kafkaClient{basicConfig: config, consumerConfig: extraConsumerConfig, producerConfig: extraProducerConfig}
}

func NewKafkaClientInstanceWithConfig(config *paramtable.KafkaConfig) *kafkaClient {
func NewKafkaClientInstanceWithConfig(ctx context.Context, config *paramtable.KafkaConfig) (*kafkaClient, error) {
kafkaConfig := getBasicConfig(config.Address.GetValue())

// connection setup timeout, default as 30000ms
if deadline, ok := ctx.Deadline(); ok {
if deadline.Before(time.Now()) {
return nil, errors.New("context timeout when new kafka client")
}
timeout := time.Until(deadline).Milliseconds()
kafkaConfig.SetKey("socket.connection.setup.timeout.ms", timeout)
}

if (config.SaslUsername.GetValue() == "" && config.SaslPassword.GetValue() != "") ||
(config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() == "") {
panic("enable security mode need config username and password at the same time!")
@ -77,7 +89,10 @@ func NewKafkaClientInstanceWithConfig(config *paramtable.KafkaConfig) *kafkaClie
return kafkaConfigMap
}

return NewKafkaClientInstanceWithConfigMap(kafkaConfig, specExtraConfig(config.ConsumerExtraConfig.GetValue()), specExtraConfig(config.ProducerExtraConfig.GetValue()))
return NewKafkaClientInstanceWithConfigMap(
kafkaConfig,
specExtraConfig(config.ConsumerExtraConfig.GetValue()),
specExtraConfig(config.ProducerExtraConfig.GetValue())), nil

}
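For illustration of the hunk above: how a context deadline can be turned into the librdkafka socket.connection.setup.timeout.ms setting. A plain map stands in for kafka.ConfigMap so the sketch runs without the confluent-kafka-go dependency; the broker address is an assumed placeholder.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// applyConnectTimeout writes the remaining time until the context deadline
// into the config map, in milliseconds, and rejects already-expired contexts.
func applyConnectTimeout(ctx context.Context, cfg map[string]interface{}) error {
	deadline, ok := ctx.Deadline()
	if !ok {
		return nil // no deadline: keep librdkafka's default (30000 ms)
	}
	if deadline.Before(time.Now()) {
		return errors.New("context timeout when creating kafka client")
	}
	cfg["socket.connection.setup.timeout.ms"] = time.Until(deadline).Milliseconds()
	return nil
}

func main() {
	cfg := map[string]interface{}{"bootstrap.servers": "localhost:9092"} // assumed address
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := applyConnectTimeout(ctx, cfg); err != nil {
		fmt.Println("abort:", err)
		return
	}
	fmt.Println(cfg)
}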
@ -364,10 +364,10 @@ func createKafkaConfig(opts ...kafkaCfgOption) *paramtable.KafkaConfig {
func TestKafkaClient_NewKafkaClientInstanceWithConfig(t *testing.T) {
config1 := createKafkaConfig(withAddr("addr"), withPasswd("password"))

assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(config1) })
assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(context.Background(), config1) })

config2 := createKafkaConfig(withAddr("addr"), withUsername("username"))
assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(config2) })
assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(context.Background(), config2) })

producerConfig := make(map[string]string)
producerConfig["client.id"] = "dc1"
@ -378,7 +378,8 @@ func TestKafkaClient_NewKafkaClientInstanceWithConfig(t *testing.T) {
config.ConsumerExtraConfig = paramtable.ParamGroup{GetFunc: func() map[string]string { return consumerConfig }}
config.ProducerExtraConfig = paramtable.ParamGroup{GetFunc: func() map[string]string { return producerConfig }}

client := NewKafkaClientInstanceWithConfig(config)
client, err := NewKafkaClientInstanceWithConfig(context.Background(), config)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.basicConfig)
@ -17,6 +17,7 @@
package nmq

import (
"context"
"fmt"
"strconv"
"time"
@ -40,8 +41,17 @@ type nmqClient struct {

// NewClientWithDefaultOptions returns a new NMQ client with default options.
// It retrieves the NMQ client URL from the server configuration.
func NewClientWithDefaultOptions() (mqwrapper.Client, error) {
func NewClientWithDefaultOptions(ctx context.Context) (mqwrapper.Client, error) {
url := Nmq.ClientURL()

if deadline, ok := ctx.Deadline(); ok {
if deadline.Before(time.Now()) {
return nil, errors.New("context timeout when new nmq client")
}
timeoutOption := nats.Timeout(time.Until(deadline))
return NewClient(url, timeoutOption)
}

return NewClient(url)
}
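A hedged sketch of the same deadline handling for a NATS connection, using the public nats.go options; nats.DefaultURL and the local-server assumption are illustrative, only the deadline-to-nats.Timeout conversion mirrors the hunk above.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/nats-io/nats.go"
)

// connectWithContext turns a context deadline into a nats.Timeout option,
// refusing to dial when the deadline has already passed.
func connectWithContext(ctx context.Context, url string) (*nats.Conn, error) {
	var opts []nats.Option
	if deadline, ok := ctx.Deadline(); ok {
		if deadline.Before(time.Now()) {
			return nil, errors.New("context timeout when creating nmq client")
		}
		opts = append(opts, nats.Timeout(time.Until(deadline)))
	}
	return nats.Connect(url, opts...)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	nc, err := connectWithContext(ctx, nats.DefaultURL) // assumes a local NATS server
	if err != nil {
		fmt.Println("connect failed:", err)
		return
	}
	defer nc.Close()
	fmt.Println("connected to", nc.ConnectedUrl())
}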
@ -38,6 +38,42 @@ func Test_NewNmqClient(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, client)
client.Close()

tests := []struct {
description string
withTimeout bool
ctxTimeouted bool
expectErr bool
}{
{"without context", false, false, false},
{"without timeout context, no timeout", true, false, false},
{"without timeout context, timeout", true, true, true},
}

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
ctx := context.Background()
var cancel context.CancelFunc
if test.withTimeout {
ctx, cancel = context.WithTimeout(ctx, time.Millisecond)
defer cancel()
}

if test.ctxTimeouted {
<-time.After(time.Millisecond)
}
client, err := NewClientWithDefaultOptions(ctx)

if test.expectErr {
assert.Error(t, err)
assert.Nil(t, client)
} else {
assert.NoError(t, err)
assert.NotNil(t, client)
client.Close()
}
})
}
}

func TestNmqClient_CreateProducer(t *testing.T) {
@ -62,9 +62,9 @@ type MsgStream interface {
GetProduceChannels() []string
Broadcast(*MsgPack) (map[string][]MessageID, error)

AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)
AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack
Seek(offset []*MsgPosition) error
Seek(ctx context.Context, offset []*MsgPosition) error

GetLatestMsgID(channel string) (MessageID, error)
CheckTopicValid(channel string) error
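A caller-side sketch of the interface change above: one context deadline now bounds both subscription and seek. The Consumer interface below only mirrors the changed MsgStream methods (positions are plain strings instead of []*msgpb.MsgPosition) so the example stays dependency-free; it is not the Milvus package itself.

package main

import (
	"context"
	"fmt"
	"time"
)

// Consumer mirrors the context-aware subset of the updated MsgStream interface.
type Consumer interface {
	AsConsumer(ctx context.Context, channels []string, subName string) error
	Seek(ctx context.Context, positions []string) error
}

type fakeStream struct{}

func (fakeStream) AsConsumer(ctx context.Context, channels []string, subName string) error {
	select {
	case <-ctx.Done():
		return ctx.Err() // subscription reports cancellation instead of blocking or panicking
	case <-time.After(10 * time.Millisecond): // pretend the broker round-trip took a moment
		return nil
	}
}

func (fakeStream) Seek(ctx context.Context, positions []string) error {
	return ctx.Err() // nil unless the shared deadline has already expired
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var s Consumer = fakeStream{}
	if err := s.AsConsumer(ctx, []string{"by-dev-rootcoord-dml_0"}, "demo-sub"); err != nil {
		fmt.Println("subscribe failed:", err)
		return
	}
	if err := s.Seek(ctx, []string{"checkpoint-0"}); err != nil {
		fmt.Println("seek failed:", err)
		return
	}
	fmt.Println("consumer ready")
}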
@ -25,7 +25,7 @@ func BenchmarkProduceAndConsumeNatsMQ(b *testing.B) {
cfg.Opts.StoreDir = storeDir
nmq.MustInitNatsMQ(cfg)

client, err := nmq.NewClientWithDefaultOptions()
client, err := nmq.NewClientWithDefaultOptions(context.Background())
if err != nil {
panic(err)
}
@ -19,6 +19,7 @@ import (
"go.uber.org/zap"

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
)

@ -26,8 +27,11 @@ import (
// fn is the func to run.
// Option can control the retry times and timeout.
func Do(ctx context.Context, fn func() error, opts ...Option) error {
log := log.Ctx(ctx)
if !funcutil.CheckCtxValid(ctx) {
return ctx.Err()
}

log := log.Ctx(ctx)
c := newDefaultConfig()

for _, opt := range opts {
@ -52,8 +56,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
select {
case <-time.After(c.sleep):
case <-ctx.Done():
el = merr.Combine(el, errors.Wrapf(ctx.Err(), "context done during sleep after run#%d", i))
return el
return merr.Combine(el, ctx.Err())
}

c.sleep *= 2
@ -19,6 +19,7 @@ import (

"github.com/cockroachdb/errors"
"github.com/lingdor/stackerror"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/stretchr/testify/assert"
)

@ -130,5 +131,6 @@ func TestContextCancel(t *testing.T) {

err := Do(ctx, testFn)
assert.Error(t, err)
assert.True(t, merr.IsCanceledOrTimeout(err))
t.Log(err)
}