mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 21:39:19 +08:00
a55f739608
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>
435 lines
13 KiB
Go
435 lines
13 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/util/distance"
|
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type calcDistanceTask struct {
|
|
traceID string
|
|
queryFunc func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error)
|
|
}
|
|
|
|
func (t *calcDistanceTask) arrangeVectorsByIntID(inputIds []int64, sequence map[int64]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
|
if retrievedVectors.GetFloatVector() != nil {
|
|
floatArr := retrievedVectors.GetFloatVector().GetData()
|
|
element := retrievedVectors.GetDim()
|
|
result := make([]float32, 0, int64(len(inputIds))*element)
|
|
for _, id := range inputIds {
|
|
index, ok := sequence[id]
|
|
if !ok {
|
|
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
|
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
|
}
|
|
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
|
}
|
|
|
|
return &schemapb.VectorField{
|
|
Dim: element,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: result,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if retrievedVectors.GetBinaryVector() != nil {
|
|
binaryArr := retrievedVectors.GetBinaryVector()
|
|
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
|
numBytes := singleBitLen / 8
|
|
|
|
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
|
for _, id := range inputIds {
|
|
index, ok := sequence[id]
|
|
if !ok {
|
|
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
|
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
|
}
|
|
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
|
}
|
|
|
|
return &schemapb.VectorField{
|
|
Dim: retrievedVectors.GetDim(),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: result,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return nil, errors.New("unsupported vector type")
|
|
}
|
|
|
|
func (t *calcDistanceTask) arrangeVectorsByStrID(inputIds []string, sequence map[string]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
|
if retrievedVectors.GetFloatVector() != nil {
|
|
floatArr := retrievedVectors.GetFloatVector().GetData()
|
|
element := retrievedVectors.GetDim()
|
|
result := make([]float32, 0, int64(len(inputIds))*element)
|
|
for _, id := range inputIds {
|
|
index, ok := sequence[id]
|
|
if !ok {
|
|
log.Error("id not found in CalcDistance", zap.String("id", id))
|
|
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
|
}
|
|
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
|
}
|
|
|
|
return &schemapb.VectorField{
|
|
Dim: element,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: result,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if retrievedVectors.GetBinaryVector() != nil {
|
|
binaryArr := retrievedVectors.GetBinaryVector()
|
|
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
|
numBytes := singleBitLen / 8
|
|
|
|
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
|
for _, id := range inputIds {
|
|
index, ok := sequence[id]
|
|
if !ok {
|
|
log.Error("id not found in CalcDistance", zap.String("id", id))
|
|
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
|
}
|
|
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
|
}
|
|
|
|
return &schemapb.VectorField{
|
|
Dim: retrievedVectors.GetDim(),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: result,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return nil, errors.New("unsupported vector type")
|
|
}
|
|
|
|
func (t *calcDistanceTask) Execute(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
|
|
param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams())
|
|
metric, err := distance.ValidateMetricType(param)
|
|
if err != nil {
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
|
|
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
|
|
var retrievedIds *schemapb.ScalarField
|
|
var retrievedVectors *schemapb.VectorField
|
|
isStringID := true
|
|
for _, fieldData := range retrievedFields {
|
|
if fieldData.FieldName == ids.FieldName {
|
|
retrievedVectors = fieldData.GetVectors()
|
|
}
|
|
if fieldData.Type == schemapb.DataType_Int64 ||
|
|
fieldData.Type == schemapb.DataType_VarChar ||
|
|
fieldData.Type == schemapb.DataType_String {
|
|
retrievedIds = fieldData.GetScalars()
|
|
|
|
if fieldData.Type == schemapb.DataType_Int64 {
|
|
isStringID = false
|
|
}
|
|
}
|
|
}
|
|
|
|
if retrievedIds == nil || retrievedVectors == nil {
|
|
return nil, errors.New("failed to fetch vectors")
|
|
}
|
|
|
|
if isStringID {
|
|
dict := make(map[string]int)
|
|
for index, id := range retrievedIds.GetStringData().GetData() {
|
|
dict[id] = index
|
|
}
|
|
|
|
inputIds := ids.IdArray.GetStrId().GetData()
|
|
return t.arrangeVectorsByStrID(inputIds, dict, retrievedVectors)
|
|
}
|
|
|
|
dict := make(map[int64]int)
|
|
for index, id := range retrievedIds.GetLongData().GetData() {
|
|
dict[id] = index
|
|
}
|
|
|
|
inputIds := ids.IdArray.GetIntId().GetData()
|
|
return t.arrangeVectorsByIntID(inputIds, dict, retrievedVectors)
|
|
}
|
|
|
|
log.Debug("CalcDistance received",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole),
|
|
zap.String("metric", metric))
|
|
|
|
vectorsLeft := request.GetOpLeft().GetDataArray()
|
|
opLeft := request.GetOpLeft().GetIdArray()
|
|
if opLeft != nil {
|
|
log.Debug("OpLeft IdArray not empty, Get vectors by id",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
result, err := t.queryFunc(opLeft)
|
|
if err != nil {
|
|
log.Debug("Failed to get left vectors by id",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("OpLeft IdArray not empty, Get vectors by id done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
|
|
if err != nil {
|
|
log.Debug("Failed to re-arrange left vectors",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("Re-arrange left vectors done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
}
|
|
|
|
if vectorsLeft == nil {
|
|
msg := "Left vectors array is empty"
|
|
log.Debug(msg,
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: msg,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
vectorsRight := request.GetOpRight().GetDataArray()
|
|
opRight := request.GetOpRight().GetIdArray()
|
|
if opRight != nil {
|
|
log.Debug("OpRight IdArray not empty, Get vectors by id",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
result, err := t.queryFunc(opRight)
|
|
if err != nil {
|
|
log.Debug("Failed to get right vectors by id",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("OpRight IdArray not empty, Get vectors by id done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
|
|
if err != nil {
|
|
log.Debug("Failed to re-arrange right vectors",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("Re-arrange right vectors done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
}
|
|
|
|
if vectorsRight == nil {
|
|
msg := "Right vectors array is empty"
|
|
log.Debug(msg,
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: msg,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if vectorsLeft.GetDim() != vectorsRight.GetDim() {
|
|
msg := "Vectors dimension is not equal"
|
|
log.Debug(msg,
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: msg,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
|
distances, err := distance.CalcFloatDistance(vectorsLeft.GetDim(), vectorsLeft.GetFloatVector().GetData(), vectorsRight.GetFloatVector().GetData(), metric)
|
|
if err != nil {
|
|
log.Debug("Failed to CalcFloatDistance",
|
|
zap.Error(err),
|
|
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
|
zap.Int("leftLen", len(vectorsLeft.GetFloatVector().GetData())),
|
|
zap.Int64("rightDim", vectorsRight.GetDim()),
|
|
zap.Int("rightLen", len(vectorsRight.GetFloatVector().GetData())),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("CalcFloatDistance done",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
|
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
|
FloatDist: &schemapb.FloatArray{
|
|
Data: distances,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
|
hamming, err := distance.CalcHammingDistance(vectorsLeft.GetDim(), vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
|
|
if err != nil {
|
|
log.Debug("Failed to CalcHammingDistance",
|
|
zap.Error(err),
|
|
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
|
zap.Int("leftLen", len(vectorsLeft.GetBinaryVector())),
|
|
zap.Int64("rightDim", vectorsRight.GetDim()),
|
|
zap.Int("rightLen", len(vectorsRight.GetBinaryVector())),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if metric == distance.HAMMING {
|
|
log.Debug("CalcHammingDistance done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
|
Array: &milvuspb.CalcDistanceResults_IntDist{
|
|
IntDist: &schemapb.IntArray{
|
|
Data: hamming,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if metric == distance.TANIMOTO {
|
|
tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.GetDim(), hamming)
|
|
if err != nil {
|
|
log.Debug("Failed to CalcTanimotoCoefficient",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
log.Debug("CalcTanimotoCoefficient done",
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
|
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
|
FloatDist: &schemapb.FloatArray{
|
|
Data: tanimoto,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
err = errors.New("unexpected error")
|
|
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
|
|
err = errors.New("cannot calculate distance between binary vectors and float vectors")
|
|
}
|
|
|
|
log.Debug("Failed to CalcDistance",
|
|
zap.Error(err),
|
|
zap.String("traceID", t.traceID),
|
|
zap.String("role", typeutil.ProxyRole))
|
|
|
|
return &milvuspb.CalcDistanceResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: err.Error(),
|
|
},
|
|
}, nil
|
|
}
|