mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-09 15:30:33 +08:00
a55f739608
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>
491 lines
12 KiB
Go
491 lines
12 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) {
|
|
task := &calcDistanceTask{}
|
|
|
|
inputIds := make([]string, 0)
|
|
inputIds = append(inputIds, "c")
|
|
inputIds = append(inputIds, "b")
|
|
inputIds = append(inputIds, "a")
|
|
|
|
sequence := make(map[string]int)
|
|
sequence["a"] = 0
|
|
sequence["b"] = 1
|
|
sequence["c"] = 2
|
|
|
|
dim := 16
|
|
|
|
// float vector
|
|
floatValue := make([]float32, 0)
|
|
for i := 0; i < dim*3; i++ {
|
|
floatValue = append(floatValue, float32(i))
|
|
}
|
|
retrievedVectors := &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: floatValue,
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
|
assert.Nil(t, err)
|
|
|
|
floatResult := result.GetFloatVector().GetData()
|
|
for i := 0; i < 3; i++ {
|
|
for j := 0; j < dim; j++ {
|
|
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
|
}
|
|
}
|
|
|
|
// binary vector
|
|
binaryValue := make([]byte, 0)
|
|
for i := 0; i < 3*dim/8; i++ {
|
|
binaryValue = append(binaryValue, byte(i))
|
|
}
|
|
retrievedVectors = &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: binaryValue,
|
|
},
|
|
}
|
|
|
|
result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
|
assert.Nil(t, err)
|
|
|
|
binaryResult := result.GetBinaryVector()
|
|
numBytes := dim / 8
|
|
for i := 0; i < 3; i++ {
|
|
for j := 0; j < numBytes; j++ {
|
|
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) {
|
|
task := &calcDistanceTask{}
|
|
|
|
inputIds := make([]int64, 0)
|
|
inputIds = append(inputIds, 2)
|
|
inputIds = append(inputIds, 0)
|
|
inputIds = append(inputIds, 1)
|
|
|
|
sequence := make(map[int64]int)
|
|
sequence[0] = 0
|
|
sequence[1] = 1
|
|
sequence[2] = 2
|
|
|
|
dim := 16
|
|
|
|
// float vector
|
|
floatValue := make([]float32, 0)
|
|
for i := 0; i < dim*3; i++ {
|
|
floatValue = append(floatValue, float32(i))
|
|
}
|
|
retrievedVectors := &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: floatValue,
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
|
assert.Nil(t, err)
|
|
|
|
floatResult := result.GetFloatVector().GetData()
|
|
for i := 0; i < 3; i++ {
|
|
for j := 0; j < dim; j++ {
|
|
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
|
}
|
|
}
|
|
|
|
// binary vector
|
|
binaryValue := make([]byte, 0)
|
|
for i := 0; i < dim*3; i++ {
|
|
binaryValue = append(binaryValue, byte(i))
|
|
}
|
|
retrievedVectors = &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: binaryValue,
|
|
},
|
|
}
|
|
|
|
result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
|
assert.Nil(t, err)
|
|
|
|
binaryResult := result.GetBinaryVector()
|
|
numBytes := dim / 8
|
|
for i := 0; i < 3; i++ {
|
|
for j := 0; j < numBytes; j++ {
|
|
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCalcDistanceTask_ExecuteFloat(t *testing.T) {
|
|
ctx := context.Background()
|
|
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
|
return nil, errors.New("unexpected error")
|
|
}
|
|
|
|
task := &calcDistanceTask{
|
|
traceID: "dummy",
|
|
queryFunc: queryFunc,
|
|
}
|
|
|
|
request := &milvuspb.CalcDistanceRequest{
|
|
OpLeft: nil,
|
|
OpRight: nil,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "L2"},
|
|
},
|
|
}
|
|
|
|
// left-op empty
|
|
calcResult, err := task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_IdArray{
|
|
IdArray: &milvuspb.VectorIDs{},
|
|
},
|
|
},
|
|
OpRight: nil,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "L2"},
|
|
},
|
|
}
|
|
|
|
// left-op query error
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
fieldIds := make([]int64, 0)
|
|
fieldIds = append(fieldIds, 2)
|
|
fieldIds = append(fieldIds, 0)
|
|
fieldIds = append(fieldIds, 1)
|
|
|
|
dim := 8
|
|
floatValue := make([]float32, 0)
|
|
for i := 0; i < dim*3; i++ {
|
|
floatValue = append(floatValue, float32(i))
|
|
}
|
|
|
|
queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
|
if ids == nil {
|
|
return &milvuspb.QueryResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: "unexpected",
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return &milvuspb.QueryResults{
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: "id",
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: fieldIds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Type: schemapb.DataType_FloatVector,
|
|
FieldName: "vec",
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: floatValue,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
task.queryFunc = queryFunc
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
idArray := &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_IdArray{
|
|
IdArray: &milvuspb.VectorIDs{
|
|
FieldName: "vec",
|
|
IdArray: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: fieldIds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: idArray,
|
|
OpRight: idArray,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "L2"},
|
|
},
|
|
}
|
|
|
|
// success
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
|
|
|
// right-op query error
|
|
request.OpRight = nil
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
request.OpRight = &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_IdArray{
|
|
IdArray: &milvuspb.VectorIDs{
|
|
FieldName: "kkk",
|
|
IdArray: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: fieldIds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// right-op arrange error
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
request.OpRight = &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_DataArray{
|
|
DataArray: &schemapb.VectorField{
|
|
Dim: 5,
|
|
},
|
|
},
|
|
}
|
|
|
|
// different dimension
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
request.OpRight = &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_DataArray{
|
|
DataArray: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: make([]float32, 0),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// calcdistance return error
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
}
|
|
|
|
func TestCalcDistanceTask_ExecuteBinary(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
fieldIds := make([]int64, 0)
|
|
fieldIds = append(fieldIds, 2)
|
|
fieldIds = append(fieldIds, 0)
|
|
fieldIds = append(fieldIds, 1)
|
|
|
|
dim := 16
|
|
binaryValue := make([]byte, 0)
|
|
for i := 0; i < 3*dim/8; i++ {
|
|
binaryValue = append(binaryValue, byte(i))
|
|
}
|
|
|
|
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
|
if ids == nil {
|
|
return &milvuspb.QueryResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: "unexpected",
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return &milvuspb.QueryResults{
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: "id",
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: fieldIds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Type: schemapb.DataType_FloatVector,
|
|
FieldName: "vec",
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: binaryValue,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
idArray := &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_IdArray{
|
|
IdArray: &milvuspb.VectorIDs{
|
|
FieldName: "vec",
|
|
IdArray: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: fieldIds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
request := &milvuspb.CalcDistanceRequest{
|
|
OpLeft: idArray,
|
|
OpRight: idArray,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "HAMMING"},
|
|
},
|
|
}
|
|
|
|
task := &calcDistanceTask{
|
|
traceID: "dummy",
|
|
queryFunc: queryFunc,
|
|
}
|
|
|
|
// success
|
|
calcResult, err := task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
|
|
|
floatArray := &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_DataArray{
|
|
DataArray: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_FloatVector{},
|
|
},
|
|
},
|
|
}
|
|
binaryArray := &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_DataArray{
|
|
DataArray: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: binaryValue,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: floatArray,
|
|
OpRight: binaryArray,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "HAMMING"},
|
|
},
|
|
}
|
|
|
|
// float vs binary
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: binaryArray,
|
|
OpRight: binaryArray,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "HAMMING"},
|
|
},
|
|
}
|
|
|
|
// hamming
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
|
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: binaryArray,
|
|
OpRight: binaryArray,
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "TANIMOTO"},
|
|
},
|
|
}
|
|
|
|
// tanimoto
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
|
|
|
request = &milvuspb.CalcDistanceRequest{
|
|
OpLeft: binaryArray,
|
|
OpRight: &milvuspb.VectorsArray{
|
|
Array: &milvuspb.VectorsArray_DataArray{
|
|
DataArray: &schemapb.VectorField{
|
|
Dim: int64(dim),
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: make([]byte, 0),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "metric", Value: "HAMMING"},
|
|
},
|
|
}
|
|
|
|
// hamming error
|
|
calcResult, err = task.Execute(ctx, request)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
|
}
|