mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 05:18:52 +08:00
9702cef2b5
issue #25639 Signed-off-by: xige-16 <xi.ge@zilliz.com> Signed-off-by: xige-16 <xi.ge@zilliz.com>
331 lines
9.9 KiB
Go
331 lines
9.9 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/internal/mocks"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
|
)
|
|
|
|
func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) {
|
|
schema := genCollectionSchema(name)
|
|
marshaledSchema, err := proto.Marshal(schema)
|
|
require.NoError(t, err)
|
|
ctx := context.TODO()
|
|
|
|
createColT := &createCollectionTask{
|
|
Condition: NewTaskCondition(context.TODO()),
|
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
|
CollectionName: name,
|
|
Schema: marshaledSchema,
|
|
ShardsNum: common.DefaultShardsNum,
|
|
},
|
|
ctx: ctx,
|
|
rootCoord: rc,
|
|
}
|
|
|
|
require.NoError(t, createColT.OnEnqueue())
|
|
require.NoError(t, createColT.PreExecute(ctx))
|
|
require.NoError(t, createColT.Execute(ctx))
|
|
require.NoError(t, createColT.PostExecute(ctx))
|
|
}
|
|
|
|
func TestHybridSearchTask_PreExecute(t *testing.T) {
|
|
var err error
|
|
|
|
var (
|
|
rc = NewRootCoordMock()
|
|
qc = mocks.NewMockQueryCoordClient(t)
|
|
ctx = context.TODO()
|
|
)
|
|
|
|
defer rc.Close()
|
|
require.NoError(t, err)
|
|
mgr := newShardClientMgr()
|
|
err = InitMetaCache(ctx, rc, qc, mgr)
|
|
require.NoError(t, err)
|
|
|
|
genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask {
|
|
task := &hybridSearchTask{
|
|
ctx: ctx,
|
|
Condition: NewTaskCondition(ctx),
|
|
request: &milvuspb.HybridSearchRequest{
|
|
CollectionName: collName,
|
|
Requests: reqs,
|
|
},
|
|
qc: qc,
|
|
tr: timerecord.NewTimeRecorder("test-hybrid-search"),
|
|
}
|
|
require.NoError(t, task.OnEnqueue())
|
|
return task
|
|
}
|
|
|
|
t.Run("bad nq 0", func(t *testing.T) {
|
|
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
|
|
createCollWithMultiVecField(t, collName, rc)
|
|
// Nq must be 1.
|
|
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}})
|
|
err = task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("bad req num 0", func(t *testing.T) {
|
|
collName := "test_bad_req_num0_error" + funcutil.GenRandomStr()
|
|
createCollWithMultiVecField(t, collName, rc)
|
|
// num of reqs must be [1, 1024].
|
|
task := genHybridSearchTaskWithNq(t, collName, nil)
|
|
err = task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("bad req num 1025", func(t *testing.T) {
|
|
collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr()
|
|
createCollWithMultiVecField(t, collName, rc)
|
|
// num of reqs must be [1, 1024].
|
|
reqs := make([]*milvuspb.SearchRequest, 0)
|
|
for i := 0; i <= defaultMaxSearchRequest; i++ {
|
|
reqs = append(reqs, &milvuspb.SearchRequest{
|
|
CollectionName: collName,
|
|
Nq: 1,
|
|
})
|
|
}
|
|
task := genHybridSearchTaskWithNq(t, collName, reqs)
|
|
err = task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("collection not exist", func(t *testing.T) {
|
|
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
|
|
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
|
err = task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("hybrid search with timeout", func(t *testing.T) {
|
|
collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr()
|
|
createCollWithMultiVecField(t, collName, rc)
|
|
|
|
task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}})
|
|
|
|
ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
|
|
defer cancel()
|
|
|
|
task.ctx = ctxTimeout
|
|
task.request.OutputFields = []string{testFloatVecField}
|
|
assert.NoError(t, task.PreExecute(ctx))
|
|
})
|
|
}
|
|
|
|
func TestHybridSearchTask_ErrExecute(t *testing.T) {
|
|
var (
|
|
err error
|
|
ctx = context.TODO()
|
|
|
|
rc = NewRootCoordMock()
|
|
qc = getQueryCoordClient()
|
|
qn = getQueryNodeClient()
|
|
|
|
collectionName = t.Name() + funcutil.GenRandomStr()
|
|
)
|
|
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
|
|
mgr := NewMockShardClientManager(t)
|
|
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
|
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
|
lb := NewLBPolicyImpl(mgr)
|
|
|
|
factory := dependency.NewDefaultFactory(true)
|
|
node, err := NewProxy(ctx, factory)
|
|
assert.NoError(t, err)
|
|
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
|
node.tsoAllocator = ×tampAllocator{
|
|
tso: newMockTimestampAllocatorInterface(),
|
|
}
|
|
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
|
assert.NoError(t, err)
|
|
node.sched = scheduler
|
|
err = node.sched.Start()
|
|
assert.NoError(t, err)
|
|
err = node.initRateCollector()
|
|
assert.NoError(t, err)
|
|
node.rootCoord = rc
|
|
node.queryCoord = qc
|
|
|
|
defer qc.Close()
|
|
|
|
err = InitMetaCache(ctx, rc, qc, mgr)
|
|
assert.NoError(t, err)
|
|
|
|
createCollWithMultiVecField(t, collectionName, rc)
|
|
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
|
assert.NoError(t, err)
|
|
|
|
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
|
assert.NoError(t, err)
|
|
|
|
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
|
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil)
|
|
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
|
Status: successStatus,
|
|
Shards: []*querypb.ShardLeadersList{
|
|
{
|
|
ChannelName: "channel-1",
|
|
NodeIds: []int64{1},
|
|
NodeAddrs: []string{"localhost:9000"},
|
|
},
|
|
},
|
|
}, nil)
|
|
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
|
Status: successStatus,
|
|
CollectionIDs: []int64{collectionID},
|
|
InMemoryPercentages: []int64{100},
|
|
}, nil)
|
|
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_LoadCollection,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
|
|
vectorFields := typeutil.GetVectorFieldSchemas(schema.CollectionSchema)
|
|
vectorFieldNames := make([]string, len(vectorFields))
|
|
for i, field := range vectorFields {
|
|
vectorFieldNames[i] = field.GetName()
|
|
}
|
|
|
|
// test begins
|
|
task := &hybridSearchTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
ctx: ctx,
|
|
result: &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
},
|
|
request: &milvuspb.HybridSearchRequest{
|
|
CollectionName: collectionName,
|
|
Requests: []*milvuspb.SearchRequest{
|
|
{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Search,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
Nq: 1,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
SearchParams: []*commonpb.KeyValuePair{
|
|
{Key: AnnsFieldKey, Value: testFloatVecField},
|
|
{Key: TopKKey, Value: "10"},
|
|
},
|
|
},
|
|
{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Search,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
Nq: 1,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
SearchParams: []*commonpb.KeyValuePair{
|
|
{Key: AnnsFieldKey, Value: testBinaryVecField},
|
|
{Key: TopKKey, Value: "10"},
|
|
},
|
|
},
|
|
},
|
|
OutputFields: vectorFieldNames,
|
|
},
|
|
qc: qc,
|
|
lb: lb,
|
|
node: node,
|
|
}
|
|
|
|
assert.NoError(t, task.OnEnqueue())
|
|
task.ctx = ctx
|
|
assert.NoError(t, task.PreExecute(ctx))
|
|
|
|
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
|
assert.Error(t, task.Execute(ctx))
|
|
|
|
qn.ExpectedCalls = nil
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
},
|
|
}, nil)
|
|
assert.Error(t, task.Execute(ctx))
|
|
}
|
|
|
|
func TestHybridSearchTask_PostExecute(t *testing.T) {
|
|
var (
|
|
rc = NewRootCoordMock()
|
|
qc = getQueryCoordClient()
|
|
qn = getQueryNodeClient()
|
|
collectionName = t.Name() + funcutil.GenRandomStr()
|
|
)
|
|
|
|
defer rc.Close()
|
|
mgr := NewMockShardClientManager(t)
|
|
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
|
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
|
|
|
|
t.Run("Test empty result", func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
err := InitMetaCache(ctx, rc, qc, mgr)
|
|
assert.NoError(t, err)
|
|
createCollWithMultiVecField(t, collectionName, rc)
|
|
|
|
schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
|
assert.NoError(t, err)
|
|
|
|
rankParams := []*commonpb.KeyValuePair{
|
|
{Key: LimitKey, Value: strconv.Itoa(3)},
|
|
{Key: OffsetKey, Value: strconv.Itoa(2)},
|
|
}
|
|
qt := &hybridSearchTask{
|
|
ctx: ctx,
|
|
Condition: NewTaskCondition(context.TODO()),
|
|
qc: nil,
|
|
tr: timerecord.NewTimeRecorder("search"),
|
|
schema: schema,
|
|
request: &milvuspb.HybridSearchRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Search,
|
|
},
|
|
CollectionName: collectionName,
|
|
RankParams: rankParams,
|
|
},
|
|
multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
|
|
}
|
|
|
|
err = qt.PostExecute(context.TODO())
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
|
})
|
|
}
|