// Mirror of https://gitee.com/milvus-io/milvus.git
// (synced 2024-12-05 05:18:52 +08:00), commit 8b3e5189e1.
// Signed-off-by: chasingegg <chao.gao@zilliz.com>
// Go source, 206 lines (6.4 KiB).
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestStatisticTask_all(t *testing.T) {
|
|
var (
|
|
err error
|
|
ctx = context.TODO()
|
|
|
|
rc = NewRootCoordMock()
|
|
qc = types.NewMockQueryCoord(t)
|
|
qn = types.NewMockQueryNode(t)
|
|
|
|
shardsNum = int32(2)
|
|
collectionName = t.Name() + funcutil.GenRandomStr()
|
|
)
|
|
|
|
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
|
qc.EXPECT().Start().Return(nil)
|
|
qc.EXPECT().Stop().Return(nil)
|
|
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
|
|
|
|
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
|
|
return qn, nil
|
|
}
|
|
|
|
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
|
|
|
|
rc.Start()
|
|
defer rc.Stop()
|
|
qc.Start()
|
|
defer qc.Stop()
|
|
qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{
|
|
Status: &successStatus,
|
|
Shards: []*querypb.ShardLeadersList{
|
|
{
|
|
ChannelName: "channel-1",
|
|
NodeIds: []int64{1, 2, 3},
|
|
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
|
},
|
|
},
|
|
}, nil)
|
|
|
|
err = InitMetaCache(ctx, rc, qc, mgr)
|
|
assert.NoError(t, err)
|
|
|
|
fieldName2Types := map[string]schemapb.DataType{
|
|
testBoolField: schemapb.DataType_Bool,
|
|
testInt32Field: schemapb.DataType_Int32,
|
|
testInt64Field: schemapb.DataType_Int64,
|
|
testFloatField: schemapb.DataType_Float,
|
|
testDoubleField: schemapb.DataType_Double,
|
|
testFloatVecField: schemapb.DataType_FloatVector,
|
|
}
|
|
if enableMultipleVectorFields {
|
|
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
|
}
|
|
|
|
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
|
|
marshaledSchema, err := proto.Marshal(schema)
|
|
assert.NoError(t, err)
|
|
|
|
createColT := &createCollectionTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
|
CollectionName: collectionName,
|
|
Schema: marshaledSchema,
|
|
ShardsNum: shardsNum,
|
|
},
|
|
ctx: ctx,
|
|
rootCoord: rc,
|
|
}
|
|
|
|
require.NoError(t, createColT.OnEnqueue())
|
|
require.NoError(t, createColT.PreExecute(ctx))
|
|
require.NoError(t, createColT.Execute(ctx))
|
|
require.NoError(t, createColT.PostExecute(ctx))
|
|
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
assert.NoError(t, err)
|
|
|
|
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
|
Status: &successStatus,
|
|
CollectionIDs: []int64{collectionID},
|
|
InMemoryPercentages: []int64{100},
|
|
}, nil)
|
|
|
|
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_LoadCollection,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
|
|
// test begins
|
|
task := &getStatisticsTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
ctx: ctx,
|
|
result: &milvuspb.GetStatisticsResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
},
|
|
},
|
|
request: &milvuspb.GetStatisticsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
},
|
|
qc: qc,
|
|
shardMgr: mgr,
|
|
}
|
|
|
|
assert.NoError(t, task.OnEnqueue())
|
|
|
|
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
},
|
|
PartitionIDs: []int64{1, 2, 3},
|
|
}, nil)
|
|
|
|
// test query task with timeout
|
|
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel1()
|
|
// before preExecute
|
|
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
|
|
task.ctx = ctx1
|
|
assert.NoError(t, task.PreExecute(ctx))
|
|
// after preExecute
|
|
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
|
|
|
task.ctx = ctx
|
|
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
|
return fmt.Errorf("fake error")
|
|
}
|
|
task.fromQueryNode = true
|
|
assert.Error(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
|
return errInvalidShardLeaders
|
|
}
|
|
task.fromQueryNode = true
|
|
assert.Error(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
task.statisticShardPolicy = mergeRoundRobinPolicy
|
|
task.fromQueryNode = true
|
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("GetStatistics failed")).Times(3)
|
|
assert.Error(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
task.statisticShardPolicy = mergeRoundRobinPolicy
|
|
task.fromQueryNode = true
|
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_NotShardLeader,
|
|
Reason: "error",
|
|
},
|
|
}, nil).Times(6)
|
|
assert.Error(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
task.statisticShardPolicy = mergeRoundRobinPolicy
|
|
task.fromQueryNode = true
|
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: "error",
|
|
},
|
|
}, nil).Times(3)
|
|
assert.Error(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
task.statisticShardPolicy = mergeRoundRobinPolicy
|
|
task.fromQueryNode = true
|
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil).Once()
|
|
assert.NoError(t, task.Execute(ctx))
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
}
|