// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package client import ( "encoding/json" "strconv" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/client/v2/entity" ) const ( spAnnsField = `anns_field` spTopK = `topk` spOffset = `offset` spLimit = `limit` spParams = `params` spMetricsType = `metric_type` spRoundDecimal = `round_decimal` spIgnoreGrowing = `ignore_growing` spGroupBy = `group_by_field` ) type SearchOption interface { Request() *milvuspb.SearchRequest } var _ SearchOption = (*searchOption)(nil) type searchOption struct { collectionName string partitionNames []string topK int offset int outputFields []string consistencyLevel entity.ConsistencyLevel useDefaultConsistencyLevel bool ignoreGrowing bool expr string // normal search request request *annRequest // TODO add sub request when support hybrid search } type annRequest struct { vectors []entity.Vector annField string metricsType entity.MetricType searchParam map[string]string groupByField string } func (opt *searchOption) Request() *milvuspb.SearchRequest { // TODO check whether search is hybrid after logic merged return opt.prepareSearchRequest(opt.request) } func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest { request := &milvuspb.SearchRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, Dsl: opt.expr, DslType: commonpb.DslType_BoolExprV1, ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel), OutputFields: opt.outputFields, } if annRequest != nil { // nq request.Nq = int64(len(annRequest.vectors)) // search param bs, _ := json.Marshal(annRequest.searchParam) params := map[string]string{ spAnnsField: annRequest.annField, spTopK: strconv.Itoa(opt.topK), spOffset: strconv.Itoa(opt.offset), spParams: string(bs), spMetricsType: string(annRequest.metricsType), spRoundDecimal: "-1", spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing), } if annRequest.groupByField != "" { params[spGroupBy] = annRequest.groupByField } request.SearchParams = entity.MapKvPairs(params) // placeholder group request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors) } return request } func (opt *searchOption) WithFilter(expr string) *searchOption { opt.expr = expr return opt } func (opt *searchOption) WithOffset(offset int) *searchOption { opt.offset = offset return opt } func (opt *searchOption) WithOutputFields(fieldNames []string) *searchOption { opt.outputFields = fieldNames return opt } func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption { opt.consistencyLevel = consistencyLevel opt.useDefaultConsistencyLevel = false return opt } func (opt *searchOption) WithANNSField(annsField string) *searchOption { opt.request.annField = annsField return opt } func (opt *searchOption) WithPartitions(partitionNames []string) *searchOption { opt.partitionNames = partitionNames return opt } func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption { return &searchOption{ collectionName: collectionName, topK: limit, request: &annRequest{ vectors: vectors, }, useDefaultConsistencyLevel: true, consistencyLevel: entity.ClBounded, } } func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte { phg := &commonpb.PlaceholderGroup{ Placeholders: []*commonpb.PlaceholderValue{ vector2Placeholder(vectors), }, } bs, _ := proto.Marshal(phg) return bs } func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue { var placeHolderType commonpb.PlaceholderType ph := &commonpb.PlaceholderValue{ Tag: "$0", Values: make([][]byte, 0, len(vectors)), } if len(vectors) == 0 { return ph } switch vectors[0].(type) { case entity.FloatVector: placeHolderType = commonpb.PlaceholderType_FloatVector case entity.BinaryVector: placeHolderType = commonpb.PlaceholderType_BinaryVector case entity.BFloat16Vector: placeHolderType = commonpb.PlaceholderType_BFloat16Vector case entity.Float16Vector: placeHolderType = commonpb.PlaceholderType_Float16Vector case entity.SparseEmbedding: placeHolderType = commonpb.PlaceholderType_SparseFloatVector } ph.Type = placeHolderType for _, vector := range vectors { ph.Values = append(ph.Values, vector.Serialize()) } return ph } type QueryOption interface { Request() *milvuspb.QueryRequest } type queryOption struct { collectionName string partitionNames []string queryParams map[string]string outputFields []string consistencyLevel entity.ConsistencyLevel useDefaultConsistencyLevel bool expr string } func (opt *queryOption) Request() *milvuspb.QueryRequest { return &milvuspb.QueryRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, OutputFields: opt.outputFields, Expr: opt.expr, QueryParams: entity.MapKvPairs(opt.queryParams), ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(), } } func (opt *queryOption) WithFilter(expr string) *queryOption { opt.expr = expr return opt } func (opt *queryOption) WithOffset(offset int) *queryOption { if opt.queryParams == nil { opt.queryParams = make(map[string]string) } opt.queryParams[spOffset] = strconv.Itoa(offset) return opt } func (opt *queryOption) WithLimit(limit int) *queryOption { if opt.queryParams == nil { opt.queryParams = make(map[string]string) } opt.queryParams[spLimit] = strconv.Itoa(limit) return opt } func (opt *queryOption) WithOutputFields(fieldNames []string) *queryOption { opt.outputFields = fieldNames return opt } func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption { opt.consistencyLevel = consistencyLevel opt.useDefaultConsistencyLevel = false return opt } func (opt *queryOption) WithPartitions(partitionNames []string) *queryOption { opt.partitionNames = partitionNames return opt } func NewQueryOption(collectionName string) *queryOption { return &queryOption{ collectionName: collectionName, useDefaultConsistencyLevel: true, consistencyLevel: entity.ClBounded, } }