mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 05:18:52 +08:00
244d2c04f6
Related to #31293 Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
252 lines
6.6 KiB
Go
252 lines
6.6 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"math/rand"
|
|
"net"
|
|
"strings"
|
|
|
|
mock "github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
)
|
|
|
|
// bufSize is the capacity, in bytes, of the in-memory bufconn listener that
// backs the mock gRPC server (1 MiB).
const bufSize = 1 << 20
|
|
|
|
type MockSuiteBase struct {
|
|
suite.Suite
|
|
|
|
lis *bufconn.Listener
|
|
svr *grpc.Server
|
|
mock *MilvusServiceServer
|
|
|
|
client *Client
|
|
}
|
|
|
|
func (s *MockSuiteBase) SetupSuite() {
|
|
s.lis = bufconn.Listen(bufSize)
|
|
s.svr = grpc.NewServer()
|
|
|
|
s.mock = &MilvusServiceServer{}
|
|
|
|
milvuspb.RegisterMilvusServiceServer(s.svr, s.mock)
|
|
|
|
go func() {
|
|
s.T().Log("start mock server")
|
|
if err := s.svr.Serve(s.lis); err != nil {
|
|
s.Fail("failed to start mock server", err.Error())
|
|
}
|
|
}()
|
|
s.setupConnect()
|
|
}
|
|
|
|
func (s *MockSuiteBase) TearDownSuite() {
|
|
s.svr.Stop()
|
|
s.lis.Close()
|
|
}
|
|
|
|
func (s *MockSuiteBase) mockDialer(context.Context, string) (net.Conn, error) {
|
|
return s.lis.Dial()
|
|
}
|
|
|
|
func (s *MockSuiteBase) SetupTest() {
|
|
c, err := New(context.Background(), &ClientConfig{
|
|
Address: "bufnet",
|
|
DialOptions: []grpc.DialOption{
|
|
grpc.WithBlock(),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithContextDialer(s.mockDialer),
|
|
},
|
|
})
|
|
s.Require().NoError(err)
|
|
s.setupConnect()
|
|
|
|
s.client = c
|
|
}
|
|
|
|
func (s *MockSuiteBase) TearDownTest() {
|
|
s.client.Close(context.Background())
|
|
s.client = nil
|
|
}
|
|
|
|
func (s *MockSuiteBase) resetMock() {
|
|
// MetaCache.reset()
|
|
if s.mock != nil {
|
|
s.mock.Calls = nil
|
|
s.mock.ExpectedCalls = nil
|
|
s.setupConnect()
|
|
}
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupConnect() {
|
|
s.mock.EXPECT().Connect(mock.Anything, mock.AnythingOfType("*milvuspb.ConnectRequest")).
|
|
Return(&milvuspb.ConnectResponse{
|
|
Status: &commonpb.Status{},
|
|
Identifier: 1,
|
|
}, nil).Maybe()
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupCache(collName string, schema *entity.Schema) {
|
|
s.client.collCache.collections.Insert(collName, &entity.Collection{
|
|
Name: collName,
|
|
Schema: schema,
|
|
})
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupHasCollection(collNames ...string) {
|
|
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
|
|
Call.Return(func(ctx context.Context, req *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse {
|
|
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
|
|
for _, collName := range collNames {
|
|
if req.GetCollectionName() == collName {
|
|
resp.Value = true
|
|
break
|
|
}
|
|
}
|
|
return resp
|
|
}, nil)
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupHasCollectionError(errorCode commonpb.ErrorCode, err error) {
|
|
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
|
|
Return(&milvuspb.BoolResponse{
|
|
Status: &commonpb.Status{ErrorCode: errorCode},
|
|
}, err)
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupHasPartition(collName string, partNames ...string) {
|
|
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
|
|
Call.Return(func(ctx context.Context, req *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse {
|
|
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
|
|
if req.GetCollectionName() == collName {
|
|
for _, partName := range partNames {
|
|
if req.GetPartitionName() == partName {
|
|
resp.Value = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return resp
|
|
}, nil)
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupHasPartitionError(errorCode commonpb.ErrorCode, err error) {
|
|
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
|
|
Return(&milvuspb.BoolResponse{
|
|
Status: &commonpb.Status{ErrorCode: errorCode},
|
|
}, err)
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupDescribeCollection(_ string, schema *entity.Schema) {
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
|
|
Call.Return(func(ctx context.Context, req *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse {
|
|
return &milvuspb.DescribeCollectionResponse{
|
|
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
|
Schema: schema.ProtoMessage(),
|
|
}
|
|
}, nil)
|
|
}
|
|
|
|
func (s *MockSuiteBase) setupDescribeCollectionError(errorCode commonpb.ErrorCode, err error) {
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
|
|
Return(&milvuspb.DescribeCollectionResponse{
|
|
Status: &commonpb.Status{ErrorCode: errorCode},
|
|
}, err)
|
|
}
|
|
|
|
func (s *MockSuiteBase) getInt64FieldData(name string, data []int64) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: name,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *MockSuiteBase) getVarcharFieldData(name string, data []string) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_VarChar,
|
|
FieldName: name,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *MockSuiteBase) getJSONBytesFieldData(name string, data [][]byte, isDynamic bool) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_JSON,
|
|
FieldName: name,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_JsonData{
|
|
JsonData: &schemapb.JSONArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
IsDynamic: isDynamic,
|
|
}
|
|
}
|
|
|
|
func (s *MockSuiteBase) getFloatVectorFieldData(name string, dim int64, data []float32) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_FloatVector,
|
|
FieldName: name,
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: dim,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *MockSuiteBase) getSuccessStatus() *commonpb.Status {
|
|
return s.getStatus(commonpb.ErrorCode_Success, "")
|
|
}
|
|
|
|
func (s *MockSuiteBase) getStatus(code commonpb.ErrorCode, reason string) *commonpb.Status {
|
|
return &commonpb.Status{
|
|
ErrorCode: code,
|
|
Reason: reason,
|
|
}
|
|
}
|
|
|
|
// letters is the alphabet randString draws from: the 26 lowercase ASCII
// letters followed by the 26 uppercase ones.
var letters = []rune("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
|
|
|
func (s *MockSuiteBase) randString(l int) string {
|
|
builder := strings.Builder{}
|
|
for i := 0; i < l; i++ {
|
|
builder.WriteRune(letters[rand.Intn(len(letters))])
|
|
}
|
|
return builder.String()
|
|
}
|