From 2aed48c433f47edc1f1bb6bc5e9207f29d87069b Mon Sep 17 00:00:00 2001 From: SimFG Date: Wed, 28 Dec 2022 10:19:31 +0800 Subject: [PATCH] Fix query node panic when watching dm channels (#21402) Signed-off-by: SimFG --- internal/querynode/impl.go | 10 ++++++++-- internal/querynode/impl_test.go | 21 ++++++++++++++++----- internal/util/concurrency/pool.go | 4 ++++ 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 8caa1f9933..db24e62c03 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -377,8 +377,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC ErrorCode: commonpb.ErrorCode_Success, }, nil }) - ret, _ := future.Await() - return ret.(*commonpb.Status), nil + ret, err := future.Await() + if status, ok := ret.(*commonpb.Status); ok { + return status, nil + } + log.Warn("fail to convert the *commonpb.Status", zap.Any("ret", ret), zap.Error(err)) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, nil } func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 72c7458449..ba94216c0d 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -20,14 +20,11 @@ import ( "context" "encoding/json" "math/rand" + "runtime" "sync" "sync/atomic" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/common" @@ -35,9 +32,14 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/concurrency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/panjf2000/ants/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func TestImpl_GetComponentStates(t *testing.T) { @@ -115,6 +117,16 @@ func TestImpl_WatchDmChannels(t *testing.T) { status, err := node.WatchDmChannels(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + + originPool := node.taskPool + defer func() { + node.taskPool = originPool + }() + node.taskPool, _ = concurrency.NewPool(runtime.GOMAXPROCS(0), ants.WithPreAlloc(true)) + node.taskPool.Release() + status, err = node.WatchDmChannels(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) }) t.Run("target not match", func(t *testing.T) { @@ -192,7 +204,6 @@ func TestImpl_WatchDmChannels(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) }) - } func TestImpl_UnsubDmChannel(t *testing.T) { diff --git a/internal/util/concurrency/pool.go b/internal/util/concurrency/pool.go index 14a6a7642b..41e680bbdd 100644 --- a/internal/util/concurrency/pool.go +++ b/internal/util/concurrency/pool.go @@ -67,3 +67,7 @@ func (pool *Pool) Cap() int { func (pool *Pool) Running() int { return pool.inner.Running() } + +func (pool *Pool) Release() { + pool.inner.Release() +}