diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 979d23263c..79c47684d2 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -18,6 +18,8 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -266,7 +268,7 @@ func (t *queryTask) Execute(ctx context.Context) error { } err := executeQuery(WithCache) - if err == errInvalidShardLeaders { + if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { log.Warn("invalid shard leaders cache, updating shardleader caches and retry search") return executeQuery(WithoutCache) } @@ -357,9 +359,13 @@ func (t *queryTask) queryShard(ctx context.Context, leaders []nodeInfo, channelI } result, err := qn.Query(ctx, req) - if err != nil || result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { - log.Warn("QueryNode query returns error", zap.Int64("nodeID", nodeID), + if err != nil { + log.Warn("QueryNode query return error", zap.Int64("nodeID", nodeID), zap.String("channel", channelID), zap.Error(err)) + return err + } + if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { + log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.String("channel", channelID)) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 036801d7d3..63b40c0641 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -20,6 +20,7 @@ import ( 
"github.com/milvus-io/milvus/internal/util/distance" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/tsoutil" @@ -302,7 +303,7 @@ func (t *searchTask) Execute(ctx context.Context) error { } err := executeSearch(WithCache) - if err == errInvalidShardLeaders || funcutil.IsGrpcErr(err) { + if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { log.Warn("first search failed, updating shardleader caches and retry search", zap.Error(err)) return executeSearch(WithoutCache) } @@ -415,11 +416,12 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []nodeInfo, channe } result, err := qn.Search(ctx, req) if err != nil { + log.Warn("QueryNode search return error", zap.Int64("nodeID", nodeID), zap.String("channel", channelID), + zap.Error(err)) return err } if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { - log.Warn("QueryNode search returns error", zap.Int64("nodeID", nodeID), - zap.Error(err)) + log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.String("channel", channelID)) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go index 101609bff7..70b3855e9f 100644 --- a/internal/util/funcutil/func.go +++ b/internal/util/funcutil/func.go @@ -355,6 +355,7 @@ func ReadBinary(endian binary.ByteOrder, bs []byte, receiver interface{}) error return binary.Read(buf, endian, receiver) } +// IsGrpcErr checks whether err is instance of grpc status error. 
func IsGrpcErr(err error) bool { if err == nil { return false diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 42d62d8862..f4a5fb6763 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -181,7 +181,7 @@ func (c *ClientBase) connect(ctx context.Context) error { ) cancel() if err != nil { - return err + return wrapErrConnect(addr, err) } if c.conn != nil { _ = c.conn.Close() diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index 486acd03b9..a7bfb0f0fc 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -17,7 +17,10 @@ package grpcclient import ( + "context" + "errors" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -33,3 +36,17 @@ func TestClientBase_GetRole(t *testing.T) { base := ClientBase{} assert.Equal(t, "", base.GetRole()) } + +func TestClientBase_connect(t *testing.T) { + t.Run("failed to connect", func(t *testing.T) { + base := ClientBase{ + getAddrFunc: func() (string, error) { + return "", nil + }, + DialTimeout: time.Millisecond, + } + err := base.connect(context.Background()) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrConnect)) + }) +} diff --git a/internal/util/grpcclient/errors.go b/internal/util/grpcclient/errors.go new file mode 100644 index 0000000000..5f831e6fcf --- /dev/null +++ b/internal/util/grpcclient/errors.go @@ -0,0 +1,50 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package grpcclient
+
+import (
+	"errors"
+	"fmt"
+)
+
+// ErrConnect is the sentinel to pass as the target of errors.Is when a caller
+// wants to know whether an error is a connection (dial) failure:
+//
+//	if errors.Is(err, grpcclient.ErrConnect) { ... }
+//
+// It is the zero value of errConnect; matching is done by type (see Is), so
+// the addr and err fields of the sentinel are irrelevant.
+var ErrConnect errConnect
+
+// Compile-time check that errConnect implements error.
+var _ error = errConnect{}
+
+// errConnect is the error returned when dialing a remote address fails.
+// It records the address that was dialed and the underlying dial error.
+type errConnect struct {
+	addr string
+	err  error
+}
+
+// Error implements the error interface.
+//
+// It is safe to call on a value without an underlying error, such as the
+// ErrConnect sentinel itself; the "reason" part is omitted in that case
+// instead of dereferencing a nil error.
+func (e errConnect) Error() string {
+	if e.err == nil {
+		if e.addr == "" {
+			return "failed to connect"
+		}
+		return fmt.Sprintf("failed to connect %s", e.addr)
+	}
+	return fmt.Sprintf("failed to connect %s, reason: %s", e.addr, e.err.Error())
+}
+
+// Unwrap returns the underlying dial error (nil for the sentinel) so that
+// errors.Is and errors.As can look through errConnect at the root cause.
+func (e errConnect) Unwrap() error {
+	return e.err
+}
+
+// Is reports whether target is, or wraps, an errConnect. This is what makes
+// errors.Is(err, ErrConnect) match every connection error regardless of the
+// address and cause it carries.
+func (e errConnect) Is(target error) bool {
+	var ce errConnect
+	return errors.As(target, &ce)
+}
+
+// wrapErrConnect wraps a dial error together with the address that was being
+// dialed into an errConnect.
+func wrapErrConnect(addr string, err error) error {
+	return errConnect{addr: addr, err: err}
+}