diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 0ce0bda398..f744a4a777 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -1938,32 +1939,45 @@ func TestSearchTask_Requery(t *testing.T) { t.Run("Test normal", func(t *testing.T) { schema := constructCollectionSchema(pkField, vecField, dim, collection) qn := mocks.NewMockQueryNodeClient(t) - qn.EXPECT().Query(mock.Anything, mock.Anything). - Return(&internalpb.RetrieveResults{ - Ids: &schemapb.IDs{ + qn.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + idFieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: pkField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + }, + } + idField := &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: ids, }, }, - }, - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - FieldName: pkField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: ids, - }, - }, - }, + } + if request.GetReq().GetOutputFieldsId()[0] == 100 { + return &internalpb.RetrieveResults{ + Ids: idField, + FieldsData: []*schemapb.FieldData{ + idFieldData, + newFloatVectorFieldData(vecField, rows, dim), }, + }, nil + } + return &internalpb.RetrieveResults{ + Ids: idField, + FieldsData: []*schemapb.FieldData{ + newFloatVectorFieldData(vecField, rows, dim), + idFieldData, }, - newFloatVectorFieldData(vecField, rows, dim), - }, - }, nil) + }, nil + }) lb := NewMockLBPolicy(t) lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {