Fix querynode panics when watch/unsub runs concurrently (#20606)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

congqixia 2022-11-15 19:03:08 +08:00 committed by GitHub
parent d8f8296b03
commit ac9a993a39
3 changed files with 85 additions and 17 deletions

View File

@@ -26,6 +26,7 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
@@ -34,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
@@ -303,6 +305,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
return status, nil
}
log := log.With(
zap.Int64("collectionID", in.GetCollectionID()),
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Strings("channels", lo.Map(in.GetInfos(), func(info *datapb.VchannelInfo, _ int) string {
return info.GetChannelName()
})),
)
task := &watchDmChannelsTask{
baseTask: baseTask{
ctx: ctx,
@@ -313,13 +323,10 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
}
startTs := time.Now()
log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
zap.String("channelName", in.Infos[0].GetChannelName()),
zap.Int64("nodeID", paramtable.GetNodeID()))
log.Info("watchDmChannels init")
// currently we only support loading one channel at a time
future := node.taskPool.Submit(func() (interface{}, error) {
log.Info("watchDmChannels start ", zap.Int64("collectionID", in.CollectionID),
zap.String("channelName", in.Infos[0].GetChannelName()),
log.Info("watchDmChannels start ",
zap.Duration("timeInQueue", time.Since(startTs)))
err := task.PreExecute(ctx)
if err != nil {
@@ -337,7 +344,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}
log.Warn("failed to subscribe channel ", zap.Error(err))
log.Warn("failed to subscribe channel", zap.Error(err))
return status, nil
}
@@ -351,10 +358,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
return status, nil
}
sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
sc.SetupFirstVersion()
log.Info("successfully watchDmChannelsTask", zap.Int64("collectionID", in.CollectionID),
zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", paramtable.GetNodeID()))
log.Info("successfully watchDmChannelsTask")
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil

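The panic fixed here comes from the lines removed above: WatchDmChannels discarded the boolean returned by getShardCluster and called SetupFirstVersion on the result, so a concurrent unsubscribe/release that had already removed the shard cluster left sc nil and the call panicked. The sketch below is a minimal, simplified reproduction of that pattern and of the guard the commit moves into PostExecute; ShardCluster, shardClusterService, and the "dmc0" channel name are stand-ins for the real querynode types, not the actual implementation.

```go
package main

import (
	"fmt"
	"sync"
)

// ShardCluster is a simplified stand-in for the querynode shard cluster type.
type ShardCluster struct {
	version int64
}

// SetupFirstVersion mutates the receiver, so calling it on a nil *ShardCluster
// panics with a nil pointer dereference.
func (sc *ShardCluster) SetupFirstVersion() {
	sc.version = 1
}

// shardClusterService is a simplified stand-in for the channel -> shard cluster registry.
type shardClusterService struct {
	mu       sync.Mutex
	clusters map[string]*ShardCluster
}

func (s *shardClusterService) getShardCluster(channel string) (*ShardCluster, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	sc, ok := s.clusters[channel]
	return sc, ok
}

func (s *shardClusterService) releaseShardCluster(channel string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.clusters, channel)
}

func main() {
	svc := &shardClusterService{clusters: map[string]*ShardCluster{"dmc0": {}}}

	// A concurrent release (e.g. triggered by an unsubscribe) removes the
	// cluster before the watch flow reaches the version setup step.
	svc.releaseShardCluster("dmc0")

	// Old, unguarded pattern: the ok flag is ignored, sc is nil, and the
	// commented-out call would panic.
	//   sc, _ := svc.getShardCluster("dmc0")
	//   sc.SetupFirstVersion() // panic: nil pointer dereference

	// Guarded pattern, analogous to the new PostExecute in the third file below:
	// check ok and fail gracefully instead of dereferencing nil.
	sc, ok := svc.getShardCluster("dmc0")
	if !ok {
		fmt.Println("failed to watch dmc0, shard cluster may be released")
		return
	}
	sc.SetupFirstVersion()
}
```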
View File

@@ -25,6 +25,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -137,10 +138,47 @@ func TestImpl_WatchDmChannels(t *testing.T) {
},
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
defer node.UpdateStateCode(commonpb.StateCode_Healthy)
status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
t.Run("mock release after loaded", func(t *testing.T) {
mockTSReplica := &MockTSafeReplicaInterface{}
oldTSReplica := node.tSafeReplica
defer func() {
node.tSafeReplica = oldTSReplica
}()
node.tSafeReplica = mockTSReplica
mockTSReplica.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
node.ShardClusterService.releaseShardCluster("1001-dmc0")
})
schema := genTestCollectionSchema()
req := &queryPb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
MsgID: rand.Int63(),
TargetID: node.session.ServerID,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
Schema: schema,
Infos: []*datapb.VchannelInfo{
{
CollectionID: 1001,
ChannelName: "1001-dmc0",
},
},
}
status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
})
}
func TestImpl_UnsubDmChannel(t *testing.T) {

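The new test above forces the interleaving deterministically: the mocked tSafe replica's addTSafe expectation runs a hook that releases the shard cluster while the watch task is still executing, so the race no longer depends on scheduler timing. Below is a minimal sketch of that hook technique using testify/mock; the replica interface, watchChannel helper, and the plain map standing in for the shard cluster registry are simplified stand-ins, not the real querynode types.

```go
package example

import (
	"testing"

	"github.com/stretchr/testify/mock"
)

// replica is a hypothetical interface standing in for the tSafe replica the
// watch task registers channels on.
type replica interface {
	addTSafe(channel string)
}

// mockReplica records calls via testify/mock so a test hook can run in the
// middle of the watched operation.
type mockReplica struct {
	mock.Mock
}

func (m *mockReplica) addTSafe(channel string) {
	m.Called(channel)
}

// watchChannel is a stand-in for the watch task body: it registers the tSafe,
// then checks whether the shard cluster still exists.
func watchChannel(r replica, clusters map[string]bool, channel string) bool {
	r.addTSafe(channel)
	return clusters[channel] // false if a concurrent release removed it
}

func TestConcurrentReleaseDuringWatch(t *testing.T) {
	clusters := map[string]bool{"1001-dmc0": true}

	r := &mockReplica{}
	// The Run hook fires synchronously inside addTSafe, simulating a release
	// task racing with the watch at exactly that point.
	r.On("addTSafe", mock.Anything).Run(func(_ mock.Arguments) {
		delete(clusters, "1001-dmc0")
	})

	if watchChannel(r, clusters, "1001-dmc0") {
		t.Fatal("expected the watch to observe the released shard cluster")
	}
	r.AssertExpectations(t)
}
```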
View File

@@ -65,16 +65,19 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
VPChannels[v] = p
}
log := log.With(
zap.Int64("collectionID", w.req.GetCollectionID()),
zap.Strings("vChannels", vChannels),
zap.Int64("replicaID", w.req.GetReplicaID()),
)
if len(VPChannels) != len(vChannels) {
return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID))
}
log.Info("Starting WatchDmChannels ...",
zap.String("collectionName", w.req.Schema.Name),
zap.Int64("collectionID", collectionID),
zap.Int64("replicaID", w.req.GetReplicaID()),
zap.String("load type", lType.String()),
zap.Strings("vChannels", vChannels),
zap.String("loadType", lType.String()),
zap.String("collectionName", w.req.GetSchema().GetName()),
)
// init collection meta
@@ -126,7 +129,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
coll.setLoadType(lType)
log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
log.Info("watchDMChannel, init replica done")
// create tSafe
for _, channel := range vChannels {
@@ -143,7 +146,30 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) {
fg.flowGraph.Start()
}
log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels))
log.Info("WatchDmChannels done")
return nil
}
// PostExecute sets up the ShardCluster first version and does not perform gc if it fails.
func (w *watchDmChannelsTask) PostExecute(ctx context.Context) error {
// setup shard cluster version
var releasedChannels []string
for _, info := range w.req.GetInfos() {
sc, ok := w.node.ShardClusterService.getShardCluster(info.GetChannelName())
// shard cluster may be released by a release task
if !ok {
releasedChannels = append(releasedChannels, info.GetChannelName())
continue
}
sc.SetupFirstVersion()
}
if len(releasedChannels) > 0 {
// no clean up needed, release shall do the job
log.Warn("WatchDmChannels failed, shard cluster may be released",
zap.Strings("releasedChannels", releasedChannels),
)
return fmt.Errorf("failed to watch %v, shard cluster may be released", releasedChannels)
}
return nil
}