Fix querynode panics when watch/unsub runs concurrently (#20606)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
Parent: d8f8296b03
Commit: ac9a993a39
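Why it panicked: WatchDmChannels fetched the shard cluster with a blank identifier, sc, _ := getShardCluster(...), discarding the ok flag. If a concurrent unsubscribe released the cluster in the meantime, sc came back nil and sc.SetupFirstVersion() dereferenced a nil pointer, crashing the querynode. The commit honors the flag and moves the setup into the task's PostExecute. Below is a minimal Go sketch of the before/after idiom; ShardCluster and ShardClusterService here are simplified stand-ins, not the real Milvus types.

// Illustrative sketch only: simplified stand-ins for the Milvus types.
package sketch

import (
	"fmt"
	"sync"
)

type ShardCluster struct{ channel string }

func (sc *ShardCluster) SetupFirstVersion() {}

type ShardClusterService struct {
	mu       sync.Mutex
	clusters map[string]*ShardCluster
}

// getShardCluster reports via ok whether the cluster still exists;
// a concurrent release deletes the map entry.
func (s *ShardClusterService) getShardCluster(ch string) (*ShardCluster, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	sc, ok := s.clusters[ch]
	return sc, ok
}

// setupBefore mirrors the removed code: the ok flag is discarded, so a
// watch racing with an unsubscribe calls SetupFirstVersion on nil.
func setupBefore(s *ShardClusterService, ch string) {
	sc, _ := s.getShardCluster(ch)
	sc.SetupFirstVersion() // panics if ch was released concurrently
}

// setupAfter mirrors the new PostExecute: check the flag and return an
// error instead of panicking when the cluster is already gone.
func setupAfter(s *ShardClusterService, ch string) error {
	sc, ok := s.getShardCluster(ch)
	if !ok {
		return fmt.Errorf("failed to watch %v, shard cluster may be released", ch)
	}
	sc.SetupFirstVersion()
	return nil
}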
@@ -26,6 +26,7 @@ import (
 	"time"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 
@@ -34,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus/internal/common"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metrics"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/util/metricsinfo"
@@ -303,6 +305,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
+	log := log.With(
+		zap.Int64("collectionID", in.GetCollectionID()),
+		zap.Int64("nodeID", paramtable.GetNodeID()),
+		zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
+			return info.GetChannelName()
+		})),
+	)
+
 	task := &watchDmChannelsTask{
 		baseTask: baseTask{
 			ctx: ctx,
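Note on the hunk above: binding the common fields once with log.With gives every later log call in the handler the collection, node, and channel context without restating it. The channel list is built with lo.Map from github.com/samber/lo, which projects a slice through a callback. A tiny standalone example of that call shape, using a stand-in vchannelInfo type rather than the datapb message:

package sketch

import "github.com/samber/lo"

type vchannelInfo struct{ ChannelName string }

// channelNames mirrors the lo.Map call in the hunk above: it maps a
// slice of channel infos to the slice of their names.
func channelNames(infos []*vchannelInfo) []string {
	return lo.Map(infos, func(info *vchannelInfo, _ int) string {
		return info.ChannelName
	})
}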
@@ -313,13 +323,10 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 	}
 
 	startTs := time.Now()
-	log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()),
-		zap.Int64("nodeID", paramtable.GetNodeID()))
+	log.Info("watchDmChannels init")
 	// currently we only support loading one channel at a time
 	future := node.taskPool.Submit(func() (interface{}, error) {
-		log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID),
-			zap.String("channelName", in.Infos[0].GetChannelName()),
+		log.Info("watchDmChannels start ",
 			zap.Duration("timeInQueue", time.Since(startTs)))
 		err := task.PreExecute(ctx)
 		if err != nil {
@@ -337,7 +344,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 			Reason:    err.Error(),
 		}
-		log.Warn("failed to subscribe channel ", zap.Error(err))
+		log.Warn("failed to subscribe channel", zap.Error(err))
 		return status, nil
 	}
 
@@ -351,10 +358,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		return status, nil
 	}
 
-	sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
-	sc.SetupFirstVersion()
-	log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
-		zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", paramtable.GetNodeID()))
+	log.Info("successfully watchDmChannelsTask")
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
 	}, nil
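The two removed lines above were the panic site: the ok result of getShardCluster was thrown away, so sc was nil whenever a release task had already torn the cluster down, and sc.SetupFirstVersion() crashed the node. That setup now happens in watchDmChannelsTask.PostExecute (final hunk below), which checks the flag first. The next two hunks extend the WatchDmChannels unit tests to reproduce the race.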
@@ -25,6 +25,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
 
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -137,10 +138,47 @@ func TestImpl_WatchDmChannels(t *testing.T) {
 			},
 		}
 		node.UpdateStateCode(commonpb.StateCode_Abnormal)
 		defer node.UpdateStateCode(commonpb.StateCode_Healthy)
 		status, err := node.WatchDmChannels(ctx, req)
 		assert.NoError(t, err)
 		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
 	})
+
+	t.Run("mock release after loaded", func(t *testing.T) {
+		mockTSReplica := &MockTSafeReplicaInterface{}
+
+		oldTSReplica := node.tSafeReplica
+		defer func() {
+			node.tSafeReplica = oldTSReplica
+		}()
+		node.tSafeReplica = mockTSReplica
+		mockTSReplica.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
+			node.ShardClusterService.releaseShardCluster("1001-dmc0")
+		})
+		schema := genTestCollectionSchema()
+		req := &queryPb.WatchDmChannelsRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:  commonpb.MsgType_WatchDmChannels,
+				MsgID:    rand.Int63(),
+				TargetID: node.session.ServerID,
+			},
+			CollectionID: defaultCollectionID,
+			PartitionIDs: []UniqueID{defaultPartitionID},
+			Schema:       schema,
+			Infos: []*datapb.VchannelInfo{
+				{
+					CollectionID: 1001,
+					ChannelName:  "1001-dmc0",
+				},
+			},
+		}
+
+		status, err := node.WatchDmChannels(ctx, req)
+		assert.NoError(t, err)
+		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
+	})
 }
 
 func TestImpl_UnsubDmChannel(t *testing.T) {
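The new subtest forces the race deterministically instead of hoping for unlucky goroutine scheduling: the mocked tSafe replica uses testify's Run hook so that, the moment addTSafe fires inside the watch path, the shard cluster for 1001-dmc0 is released, and the handler is expected to return UnexpectedError instead of panicking. A generic sketch of this hook pattern follows; replica is a hypothetical stand-in for MockTSafeReplicaInterface:

package sketch

import (
	"testing"

	"github.com/stretchr/testify/mock"
)

// replica is a hypothetical mock, standing in for MockTSafeReplicaInterface.
type replica struct{ mock.Mock }

func (r *replica) addTSafe(channel string) { r.Called(channel) }

// TestInjectRace fires a side effect at a precise point inside the code
// under test, the same technique the subtest above uses to simulate a
// concurrent release.
func TestInjectRace(t *testing.T) {
	r := &replica{}
	released := false
	r.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
		released = true // stand-in for releaseShardCluster("1001-dmc0")
	})
	r.addTSafe("1001-dmc0")
	if !released {
		t.Fatal("hook did not fire")
	}
}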
@@ -65,16 +65,19 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 		VPChannels[v] = p
 	}
 
+	log := log.With(
+		zap.Int64("collectionID", w.req.GetCollectionID()),
+		zap.Strings("vChannels", vChannels),
+		zap.Int64("replicaID", w.req.GetReplicaID()),
+	)
+
 	if len(VPChannels) != len(vChannels) {
 		return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
 	}
 
 	log.Info("Starting WatchDmChannels ...",
-		zap.String("collectionName", w.req.Schema.Name),
-		zap.Int64("collectionID", collectionID),
-		zap.Int64("replicaID", w.req.GetReplicaID()),
-		zap.String("load type", lType.String()),
-		zap.Strings("vChannels", vChannels),
+		zap.String("loadType", lType.String()),
+		zap.String("collectionName", w.req.GetSchema().GetName()),
 	)
 
 	// init collection meta
@@ -126,7 +129,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 
 	coll.setLoadType(lType)
 
-	log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
+	log.Info("watchDMChannel, init replica done")
 
 	// create tSafe
 	for _, channel := range vChannels {
@@ -143,7 +146,30 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
 		fg.flowGraph.Start()
 	}
 
-	log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
+	log.Info("WatchDmChannels done")
 	return nil
 }
+
+// PostExecute sets up the ShardCluster first version and does not run gc if that fails.
+func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
+	// setup shard cluster version
+	var releasedChannels []string
+	for _, info := range w.req.GetInfos() {
+		sc, ok := w.node.ShardClusterService.getShardCluster(info.GetChannelName())
+		// the shard cluster may have been released by a release task
+		if !ok {
+			releasedChannels = append(releasedChannels, info.GetChannelName())
+			continue
+		}
+		sc.SetupFirstVersion()
+	}
+	if len(releasedChannels) > 0 {
+		// no cleanup needed, release shall do the job
+		log.Warn("WatchDmChannels failed, shard cluster may be released",
+			zap.Strings("releasedChannels", releasedChannels),
+		)
+		return fmt.Errorf("failed to watch %v, shard cluster may be released", releasedChannels)
+	}
+	return nil
+}
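Design note: PostExecute deliberately does no cleanup when it finds released channels; as its comment says, the competing release task already owns the teardown, so the watch task only reports the failure. The error propagates out of the task future, and the handler answers with ErrorCode_UnexpectedError (the code the new test asserts) instead of bringing the whole querynode down.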