mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 03:18:29 +08:00
ebd0af14f4
pr: #33760 Issue: #22837 Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
283 lines
7.3 KiB
Go
283 lines
7.3 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package integration
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"math/rand"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/util/testutils"
|
|
)
|
|
|
|
const (
|
|
AnnsFieldKey = "anns_field"
|
|
TopKKey = "topk"
|
|
NQKey = "nq"
|
|
MetricTypeKey = common.MetricTypeKey
|
|
SearchParamsKey = common.IndexParamsKey
|
|
RoundDecimalKey = "round_decimal"
|
|
OffsetKey = "offset"
|
|
LimitKey = "limit"
|
|
)
|
|
|
|
func (s *MiniClusterSuite) WaitForLoadWithDB(ctx context.Context, dbName, collection string) {
|
|
s.waitForLoadInternal(ctx, dbName, collection)
|
|
}
|
|
|
|
func (s *MiniClusterSuite) WaitForLoad(ctx context.Context, collection string) {
|
|
s.waitForLoadInternal(ctx, "", collection)
|
|
}
|
|
|
|
func (s *MiniClusterSuite) waitForLoadInternal(ctx context.Context, dbName, collection string) {
|
|
cluster := s.Cluster
|
|
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
|
|
loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
DbName: dbName,
|
|
CollectionName: collection,
|
|
})
|
|
if err != nil {
|
|
panic("GetLoadingProgress fail")
|
|
}
|
|
return loadProgress
|
|
}
|
|
for getLoadingProgress().GetProgress() != 100 {
|
|
select {
|
|
case <-ctx.Done():
|
|
s.FailNow("failed to wait for load")
|
|
return
|
|
default:
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *MiniClusterSuite) WaitForLoadRefresh(ctx context.Context, dbName, collection string) {
|
|
cluster := s.Cluster
|
|
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
|
|
loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
DbName: dbName,
|
|
CollectionName: collection,
|
|
})
|
|
if err != nil {
|
|
panic("GetLoadingProgress fail")
|
|
}
|
|
return loadProgress
|
|
}
|
|
for getLoadingProgress().GetRefreshProgress() != 100 {
|
|
select {
|
|
case <-ctx.Done():
|
|
s.FailNow("failed to wait for load (refresh)")
|
|
return
|
|
default:
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ConstructSearchRequest(
|
|
dbName, collectionName string,
|
|
expr string,
|
|
vecField string,
|
|
vectorType schemapb.DataType,
|
|
outputFields []string,
|
|
metricType string,
|
|
params map[string]any,
|
|
nq, dim int, topk, roundDecimal int,
|
|
) *milvuspb.SearchRequest {
|
|
b, err := json.Marshal(params)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
plg := constructPlaceholderGroup(nq, dim, vectorType)
|
|
plgBs, err := proto.Marshal(plg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return &milvuspb.SearchRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collectionName,
|
|
PartitionNames: nil,
|
|
Dsl: expr,
|
|
PlaceholderGroup: plgBs,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
OutputFields: outputFields,
|
|
SearchParams: []*commonpb.KeyValuePair{
|
|
{
|
|
Key: common.MetricTypeKey,
|
|
Value: metricType,
|
|
},
|
|
{
|
|
Key: SearchParamsKey,
|
|
Value: string(b),
|
|
},
|
|
{
|
|
Key: AnnsFieldKey,
|
|
Value: vecField,
|
|
},
|
|
{
|
|
Key: common.TopKKey,
|
|
Value: strconv.Itoa(topk),
|
|
},
|
|
{
|
|
Key: RoundDecimalKey,
|
|
Value: strconv.Itoa(roundDecimal),
|
|
},
|
|
},
|
|
TravelTimestamp: 0,
|
|
GuaranteeTimestamp: 0,
|
|
Nq: int64(nq),
|
|
}
|
|
}
|
|
|
|
func ConstructSearchRequestWithConsistencyLevel(
|
|
dbName, collectionName string,
|
|
expr string,
|
|
vecField string,
|
|
vectorType schemapb.DataType,
|
|
outputFields []string,
|
|
metricType string,
|
|
params map[string]any,
|
|
nq, dim int, topk, roundDecimal int,
|
|
useDefaultConsistency bool,
|
|
consistencyLevel commonpb.ConsistencyLevel,
|
|
) *milvuspb.SearchRequest {
|
|
b, err := json.Marshal(params)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
plg := constructPlaceholderGroup(nq, dim, vectorType)
|
|
plgBs, err := proto.Marshal(plg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return &milvuspb.SearchRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collectionName,
|
|
PartitionNames: nil,
|
|
Dsl: expr,
|
|
PlaceholderGroup: plgBs,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
OutputFields: outputFields,
|
|
SearchParams: []*commonpb.KeyValuePair{
|
|
{
|
|
Key: common.MetricTypeKey,
|
|
Value: metricType,
|
|
},
|
|
{
|
|
Key: SearchParamsKey,
|
|
Value: string(b),
|
|
},
|
|
{
|
|
Key: AnnsFieldKey,
|
|
Value: vecField,
|
|
},
|
|
{
|
|
Key: common.TopKKey,
|
|
Value: strconv.Itoa(topk),
|
|
},
|
|
{
|
|
Key: RoundDecimalKey,
|
|
Value: strconv.Itoa(roundDecimal),
|
|
},
|
|
},
|
|
TravelTimestamp: 0,
|
|
GuaranteeTimestamp: 0,
|
|
UseDefaultConsistency: useDefaultConsistency,
|
|
ConsistencyLevel: consistencyLevel,
|
|
}
|
|
}
|
|
|
|
func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
|
|
values := make([][]byte, 0, nq)
|
|
var placeholderType commonpb.PlaceholderType
|
|
switch vectorType {
|
|
case schemapb.DataType_FloatVector:
|
|
placeholderType = commonpb.PlaceholderType_FloatVector
|
|
for i := 0; i < nq; i++ {
|
|
bs := make([]byte, 0, dim*4)
|
|
for j := 0; j < dim; j++ {
|
|
var buffer bytes.Buffer
|
|
f := rand.Float32()
|
|
err := binary.Write(&buffer, common.Endian, f)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
bs = append(bs, buffer.Bytes()...)
|
|
}
|
|
values = append(values, bs)
|
|
}
|
|
case schemapb.DataType_BinaryVector:
|
|
placeholderType = commonpb.PlaceholderType_BinaryVector
|
|
for i := 0; i < nq; i++ {
|
|
total := dim / 8
|
|
ret := make([]byte, total)
|
|
_, err := rand.Read(ret)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
values = append(values, ret)
|
|
}
|
|
case schemapb.DataType_Float16Vector:
|
|
placeholderType = commonpb.PlaceholderType_Float16Vector
|
|
data := testutils.GenerateFloat16Vectors(nq, dim)
|
|
for i := 0; i < nq; i++ {
|
|
rowBytes := dim * 2
|
|
values = append(values, data[rowBytes*i:rowBytes*(i+1)])
|
|
}
|
|
case schemapb.DataType_BFloat16Vector:
|
|
placeholderType = commonpb.PlaceholderType_BFloat16Vector
|
|
data := testutils.GenerateBFloat16Vectors(nq, dim)
|
|
for i := 0; i < nq; i++ {
|
|
rowBytes := dim * 2
|
|
values = append(values, data[rowBytes*i:rowBytes*(i+1)])
|
|
}
|
|
case schemapb.DataType_SparseFloatVector:
|
|
// for sparse, all query rows are encoded in a single byte array
|
|
values = make([][]byte, 0, 1)
|
|
placeholderType = commonpb.PlaceholderType_SparseFloatVector
|
|
sparseVecs := GenerateSparseFloatArray(nq)
|
|
values = append(values, sparseVecs.Contents...)
|
|
default:
|
|
panic("invalid vector data type")
|
|
}
|
|
|
|
return &commonpb.PlaceholderGroup{
|
|
Placeholders: []*commonpb.PlaceholderValue{
|
|
{
|
|
Tag: "$0",
|
|
Type: placeholderType,
|
|
Values: values,
|
|
},
|
|
},
|
|
}
|
|
}
|