mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
enhance: add efficient distance computations in Go (#28657)
Related #28656 Add more efficient calc_distance at Go side. --------- Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
parent
606ec77b66
commit
5bd73d5503
97
pkg/util/distance/asm/ip.go
Normal file
97
pkg/util/distance/asm/ip.go
Normal file
@ -0,0 +1,97 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
. "github.com/mmcloughlin/avo/build"
|
||||
. "github.com/mmcloughlin/avo/operand"
|
||||
. "github.com/mmcloughlin/avo/reg"
|
||||
)
|
||||
|
||||
var unroll = 4
|
||||
|
||||
// inspired by the avo example https://github.com/mmcloughlin/avo
|
||||
// avo is a tool to generate go assembly
|
||||
func main() {
|
||||
TEXT("IP", NOSPLIT, "func(x, y []float32) float32")
|
||||
Doc("inner product between x and y")
|
||||
x := Mem{Base: Load(Param("x").Base(), GP64())}
|
||||
y := Mem{Base: Load(Param("y").Base(), GP64())}
|
||||
n := Load(Param("x").Len(), GP64())
|
||||
|
||||
acc := make([]VecVirtual, unroll)
|
||||
for i := 0; i < unroll; i++ {
|
||||
acc[i] = YMM()
|
||||
}
|
||||
|
||||
// Zero initialization
|
||||
for i := 0; i < unroll; i++ {
|
||||
VXORPS(acc[i], acc[i], acc[i])
|
||||
}
|
||||
|
||||
// Loop over blocks and process them with vector instructions
|
||||
blockitems := 8 * unroll
|
||||
blocksize := 4 * blockitems
|
||||
Label("blockloop")
|
||||
CMPQ(n, U32(blockitems))
|
||||
JL(LabelRef("tail"))
|
||||
|
||||
// Load x
|
||||
xs := make([]VecVirtual, unroll)
|
||||
for i := 0; i < unroll; i++ {
|
||||
xs[i] = YMM()
|
||||
}
|
||||
|
||||
for i := 0; i < unroll; i++ {
|
||||
VMOVUPS(x.Offset(32*i), xs[i])
|
||||
}
|
||||
|
||||
// The actual FMA
|
||||
for i := 0; i < unroll; i++ {
|
||||
VFMADD231PS(y.Offset(32*i), xs[i], acc[i])
|
||||
}
|
||||
|
||||
ADDQ(U32(blocksize), x.Base)
|
||||
ADDQ(U32(blocksize), y.Base)
|
||||
SUBQ(U32(blockitems), n)
|
||||
JMP(LabelRef("blockloop"))
|
||||
|
||||
// Process any trailing entries
|
||||
Label("tail")
|
||||
tail := XMM()
|
||||
VXORPS(tail, tail, tail)
|
||||
|
||||
Label("tailloop")
|
||||
CMPQ(n, U32(0))
|
||||
JE(LabelRef("reduce"))
|
||||
|
||||
xt := XMM()
|
||||
VMOVSS(x, xt)
|
||||
VFMADD231SS(y, xt, tail)
|
||||
|
||||
ADDQ(U32(4), x.Base)
|
||||
ADDQ(U32(4), y.Base)
|
||||
DECQ(n)
|
||||
JMP(LabelRef("tailloop"))
|
||||
|
||||
// Reduce the lanes to one.
|
||||
Label("reduce")
|
||||
|
||||
// Manual reduction
|
||||
VADDPS(acc[0], acc[1], acc[0])
|
||||
VADDPS(acc[2], acc[3], acc[2])
|
||||
VADDPS(acc[0], acc[2], acc[0])
|
||||
|
||||
result := acc[0].AsX()
|
||||
top := XMM()
|
||||
VEXTRACTF128(U8(1), acc[0], top)
|
||||
VADDPS(result, top, result)
|
||||
VADDPS(result, tail, result)
|
||||
VHADDPS(result, result, result)
|
||||
VHADDPS(result, result, result)
|
||||
Store(result, ReturnIndex(0))
|
||||
|
||||
RET()
|
||||
|
||||
Generate()
|
||||
}
|
55
pkg/util/distance/asm/ip.s
Normal file
55
pkg/util/distance/asm/ip.s
Normal file
@ -0,0 +1,55 @@
|
||||
// Code generated by command: go run ip.go -out ip.s -stubs ip_stub.go. DO NOT EDIT.
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// func IP(x []float32, y []float32) float32
|
||||
// Requires: AVX, FMA3, SSE
|
||||
TEXT ·IP(SB), NOSPLIT, $0-52
|
||||
MOVQ x_base+0(FP), AX
|
||||
MOVQ y_base+24(FP), CX
|
||||
MOVQ x_len+8(FP), DX
|
||||
VXORPS Y0, Y0, Y0
|
||||
VXORPS Y1, Y1, Y1
|
||||
VXORPS Y2, Y2, Y2
|
||||
VXORPS Y3, Y3, Y3
|
||||
|
||||
blockloop:
|
||||
CMPQ DX, $0x00000020
|
||||
JL tail
|
||||
VMOVUPS (AX), Y4
|
||||
VMOVUPS 32(AX), Y5
|
||||
VMOVUPS 64(AX), Y6
|
||||
VMOVUPS 96(AX), Y7
|
||||
VFMADD231PS (CX), Y4, Y0
|
||||
VFMADD231PS 32(CX), Y5, Y1
|
||||
VFMADD231PS 64(CX), Y6, Y2
|
||||
VFMADD231PS 96(CX), Y7, Y3
|
||||
ADDQ $0x00000080, AX
|
||||
ADDQ $0x00000080, CX
|
||||
SUBQ $0x00000020, DX
|
||||
JMP blockloop
|
||||
|
||||
tail:
|
||||
VXORPS X4, X4, X4
|
||||
|
||||
tailloop:
|
||||
CMPQ DX, $0x00000000
|
||||
JE reduce
|
||||
VMOVSS (AX), X5
|
||||
VFMADD231SS (CX), X5, X4
|
||||
ADDQ $0x00000004, AX
|
||||
ADDQ $0x00000004, CX
|
||||
DECQ DX
|
||||
JMP tailloop
|
||||
|
||||
reduce:
|
||||
VADDPS Y0, Y1, Y0
|
||||
VADDPS Y2, Y3, Y2
|
||||
VADDPS Y0, Y2, Y0
|
||||
VEXTRACTF128 $0x01, Y0, X1
|
||||
VADDPS X0, X1, X0
|
||||
VADDPS X0, X4, X0
|
||||
VHADDPS X0, X0, X0
|
||||
VHADDPS X0, X0, X0
|
||||
MOVSS X0, ret+48(FP)
|
||||
RET
|
6
pkg/util/distance/asm/ip_stub.go
Normal file
6
pkg/util/distance/asm/ip_stub.go
Normal file
@ -0,0 +1,6 @@
|
||||
// Code generated by command: go run ip.go -out ip.s -stubs ip_stub.go. DO NOT EDIT.
|
||||
|
||||
package asm
|
||||
|
||||
// inner product between x and y
|
||||
func IP(x []float32, y []float32) float32
|
105
pkg/util/distance/asm/l2.go
Normal file
105
pkg/util/distance/asm/l2.go
Normal file
@ -0,0 +1,105 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
. "github.com/mmcloughlin/avo/build"
|
||||
. "github.com/mmcloughlin/avo/operand"
|
||||
. "github.com/mmcloughlin/avo/reg"
|
||||
)
|
||||
|
||||
var unroll = 4
|
||||
|
||||
// inspired by the avo example https://github.com/mmcloughlin/avo
|
||||
// avo is a tool to generate go assembly
|
||||
func main() {
|
||||
TEXT("L2", NOSPLIT, "func(x, y []float32) float32")
|
||||
Doc("squared l2 between x and y")
|
||||
x := Mem{Base: Load(Param("x").Base(), GP64())}
|
||||
y := Mem{Base: Load(Param("y").Base(), GP64())}
|
||||
n := Load(Param("x").Len(), GP64())
|
||||
|
||||
acc := make([]VecVirtual, unroll)
|
||||
diff := make([]VecVirtual, unroll)
|
||||
for i := 0; i < unroll; i++ {
|
||||
acc[i] = YMM()
|
||||
diff[i] = YMM()
|
||||
}
|
||||
|
||||
for i := 0; i < unroll; i++ {
|
||||
VXORPS(acc[i], acc[i], acc[i])
|
||||
VXORPS(diff[i], diff[i], diff[i])
|
||||
}
|
||||
|
||||
blockitems := 8 * unroll
|
||||
blocksize := 4 * blockitems
|
||||
Label("blockloop")
|
||||
CMPQ(n, U32(blockitems))
|
||||
JL(LabelRef("tail"))
|
||||
|
||||
// Load x
|
||||
xs := make([]VecVirtual, unroll)
|
||||
for i := 0; i < unroll; i++ {
|
||||
xs[i] = YMM()
|
||||
}
|
||||
|
||||
for i := 0; i < unroll; i++ {
|
||||
VMOVUPS(x.Offset(32*i), xs[i])
|
||||
}
|
||||
|
||||
for i := 0; i < unroll; i++ {
|
||||
VSUBPS(y.Offset(32*i), xs[i], diff[i])
|
||||
}
|
||||
|
||||
for i := 0; i < unroll; i++ {
|
||||
VFMADD231PS(diff[i], diff[i], acc[i])
|
||||
}
|
||||
|
||||
ADDQ(U32(blocksize), x.Base)
|
||||
ADDQ(U32(blocksize), y.Base)
|
||||
SUBQ(U32(blockitems), n)
|
||||
JMP(LabelRef("blockloop"))
|
||||
|
||||
// Process any trailing entries
|
||||
Label("tail")
|
||||
tail := XMM()
|
||||
VXORPS(tail, tail, tail)
|
||||
|
||||
Label("tailloop")
|
||||
CMPQ(n, U32(0))
|
||||
JE(LabelRef("reduce"))
|
||||
|
||||
xt := XMM()
|
||||
VMOVSS(x, xt)
|
||||
|
||||
difft := XMM()
|
||||
VSUBSS(y, xt, difft)
|
||||
|
||||
VFMADD231SS(difft, difft, tail)
|
||||
|
||||
ADDQ(U32(4), x.Base)
|
||||
ADDQ(U32(4), y.Base)
|
||||
DECQ(n)
|
||||
JMP(LabelRef("tailloop"))
|
||||
|
||||
// Reduce the lanes to one
|
||||
Label("reduce")
|
||||
|
||||
// Manual reduction
|
||||
VADDPS(acc[0], acc[1], acc[0])
|
||||
VADDPS(acc[2], acc[3], acc[2])
|
||||
VADDPS(acc[0], acc[2], acc[0])
|
||||
|
||||
result := acc[0].AsX()
|
||||
top := XMM()
|
||||
VEXTRACTF128(U8(1), acc[0], top)
|
||||
VADDPS(result, top, result)
|
||||
VADDPS(result, tail, result)
|
||||
VHADDPS(result, result, result)
|
||||
VHADDPS(result, result, result)
|
||||
Store(result, ReturnIndex(0))
|
||||
|
||||
RET()
|
||||
|
||||
Generate()
|
||||
}
|
64
pkg/util/distance/asm/l2.s
Normal file
64
pkg/util/distance/asm/l2.s
Normal file
@ -0,0 +1,64 @@
|
||||
// Code generated by command: go run l2.go -out l2.s -stubs l2_stub.go. DO NOT EDIT.
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// func L2(x []float32, y []float32) float32
|
||||
// Requires: AVX, FMA3, SSE
|
||||
TEXT ·L2(SB), NOSPLIT, $0-52
|
||||
MOVQ x_base+0(FP), AX
|
||||
MOVQ y_base+24(FP), CX
|
||||
MOVQ x_len+8(FP), DX
|
||||
VXORPS Y0, Y0, Y0
|
||||
VXORPS Y1, Y1, Y1
|
||||
VXORPS Y2, Y2, Y2
|
||||
VXORPS Y3, Y3, Y3
|
||||
VXORPS Y4, Y4, Y4
|
||||
VXORPS Y5, Y5, Y5
|
||||
VXORPS Y6, Y6, Y6
|
||||
VXORPS Y7, Y7, Y7
|
||||
|
||||
blockloop:
|
||||
CMPQ DX, $0x00000020
|
||||
JL tail
|
||||
VMOVUPS (AX), Y1
|
||||
VMOVUPS 32(AX), Y3
|
||||
VMOVUPS 64(AX), Y5
|
||||
VMOVUPS 96(AX), Y7
|
||||
VSUBPS (CX), Y1, Y1
|
||||
VSUBPS 32(CX), Y3, Y3
|
||||
VSUBPS 64(CX), Y5, Y5
|
||||
VSUBPS 96(CX), Y7, Y7
|
||||
VFMADD231PS Y1, Y1, Y0
|
||||
VFMADD231PS Y3, Y3, Y2
|
||||
VFMADD231PS Y5, Y5, Y4
|
||||
VFMADD231PS Y7, Y7, Y6
|
||||
ADDQ $0x00000080, AX
|
||||
ADDQ $0x00000080, CX
|
||||
SUBQ $0x00000020, DX
|
||||
JMP blockloop
|
||||
|
||||
tail:
|
||||
VXORPS X1, X1, X1
|
||||
|
||||
tailloop:
|
||||
CMPQ DX, $0x00000000
|
||||
JE reduce
|
||||
VMOVSS (AX), X3
|
||||
VSUBSS (CX), X3, X3
|
||||
VFMADD231SS X3, X3, X1
|
||||
ADDQ $0x00000004, AX
|
||||
ADDQ $0x00000004, CX
|
||||
DECQ DX
|
||||
JMP tailloop
|
||||
|
||||
reduce:
|
||||
VADDPS Y0, Y2, Y0
|
||||
VADDPS Y4, Y6, Y4
|
||||
VADDPS Y0, Y4, Y0
|
||||
VEXTRACTF128 $0x01, Y0, X2
|
||||
VADDPS X0, X2, X0
|
||||
VADDPS X0, X1, X0
|
||||
VHADDPS X0, X0, X0
|
||||
VHADDPS X0, X0, X0
|
||||
MOVSS X0, ret+48(FP)
|
||||
RET
|
6
pkg/util/distance/asm/l2_stub.go
Normal file
6
pkg/util/distance/asm/l2_stub.go
Normal file
@ -0,0 +1,6 @@
|
||||
// Code generated by command: go run l2.go -out l2.s -stubs l2_stub.go. DO NOT EDIT.
|
||||
|
||||
package asm
|
||||
|
||||
// squared l2 between x and y
|
||||
func L2(x []float32, y []float32) float32
|
166
pkg/util/distance/calc_distance.go
Normal file
166
pkg/util/distance/calc_distance.go
Normal file
@ -0,0 +1,166 @@
|
||||
package distance
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"golang.org/x/sys/cpu"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance/asm"
|
||||
)
|
||||
|
||||
/**
|
||||
* Delete in #25663 Remove calc_distance
|
||||
* Add back partially as clustering feature needs to calculate distance between search vector and clustering center
|
||||
*/
|
||||
const (
|
||||
// L2 represents the Euclidean distance
|
||||
L2 = "L2"
|
||||
// IP represents the inner product distance
|
||||
IP = "IP"
|
||||
// COSINE represents the cosine distance
|
||||
COSINE = "COSINE"
|
||||
)
|
||||
|
||||
func L2ImplPure(a []float32, b []float32) float32 {
|
||||
var sum float32
|
||||
|
||||
for i := range a {
|
||||
sum += (a[i] - b[i]) * (a[i] - b[i])
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
func IPImplPure(a []float32, b []float32) float32 {
|
||||
var sum float32
|
||||
|
||||
for i := range a {
|
||||
sum += a[i] * b[i]
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
func CosineImplPure(a []float32, b []float32) float32 {
|
||||
var sum, normA, normB float32
|
||||
|
||||
for i := range a {
|
||||
sum += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
return sum / float32(math.Sqrt(float64(normA)*float64(normB)))
|
||||
}
|
||||
|
||||
var (
|
||||
L2Impl func(a []float32, b []float32) float32
|
||||
IPImpl func(a []float32, b []float32) float32
|
||||
CosineImpl func(a []float32, b []float32) float32
|
||||
)
|
||||
|
||||
func init() {
|
||||
if cpu.X86.HasAVX2 {
|
||||
log.Info("Hook avx for go simd distance computation")
|
||||
IPImpl = asm.IP
|
||||
L2Impl = asm.L2
|
||||
CosineImpl = func(a []float32, b []float32) float32 {
|
||||
return asm.IP(a, b) / float32(math.Sqrt(float64(asm.IP(a, a))*float64((asm.IP(b, b)))))
|
||||
}
|
||||
} else {
|
||||
log.Info("Use pure go distance computation")
|
||||
IPImpl = IPImplPure
|
||||
L2Impl = L2ImplPure
|
||||
CosineImpl = CosineImplPure
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateMetricType returns metric text or error
|
||||
func ValidateMetricType(metric string) (string, error) {
|
||||
if metric == "" {
|
||||
err := errors.New("metric type is empty")
|
||||
return "", err
|
||||
}
|
||||
|
||||
m := strings.ToUpper(metric)
|
||||
if m == L2 || m == IP || m == COSINE {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
err := errors.New("invalid metric type")
|
||||
return metric, err
|
||||
}
|
||||
|
||||
// ValidateFloatArrayLength is used validate float vector length
|
||||
func ValidateFloatArrayLength(dim int64, length int) error {
|
||||
if length == 0 || int64(length)%dim != 0 {
|
||||
err := errors.New("invalid float vector length")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalcFFBatch calculate the distance of @left & @right vectors in batch by given @metic, store result in @result
|
||||
func CalcFFBatch(dim int64, left []float32, lIndex int64, right []float32, metric string, result *[]float32) {
|
||||
rightNum := int64(len(right)) / dim
|
||||
for i := int64(0); i < rightNum; i++ {
|
||||
var distance float32 = -1.0
|
||||
if metric == L2 {
|
||||
distance = L2Impl(left[lIndex*dim:lIndex*dim+dim], right[i*dim:i*dim+dim])
|
||||
} else if metric == IP {
|
||||
distance = IPImpl(left[lIndex*dim:lIndex*dim+dim], right[i*dim:i*dim+dim])
|
||||
} else if metric == COSINE {
|
||||
distance = CosineImpl(left[lIndex*dim:lIndex*dim+dim], right[i*dim:i*dim+dim])
|
||||
}
|
||||
(*result)[lIndex*rightNum+i] = distance
|
||||
}
|
||||
}
|
||||
|
||||
// CalcFloatDistance calculate float distance by given metric
|
||||
// it will checks input, and calculate the distance concurrently
|
||||
func CalcFloatDistance(dim int64, left, right []float32, metric string) ([]float32, error) {
|
||||
if dim <= 0 {
|
||||
err := errors.New("invalid dimension")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricUpper := strings.ToUpper(metric)
|
||||
if metricUpper != L2 && metricUpper != IP && metricUpper != COSINE {
|
||||
err := errors.New("invalid metric type")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := ValidateFloatArrayLength(dim, len(left))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ValidateFloatArrayLength(dim, len(right))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftNum := int64(len(left)) / dim
|
||||
rightNum := int64(len(right)) / dim
|
||||
|
||||
distArray := make([]float32, leftNum*rightNum)
|
||||
|
||||
// Multi-threads to calculate distance. TODO: avoid too many go routines
|
||||
var waitGroup sync.WaitGroup
|
||||
CalcWorker := func(index int64) {
|
||||
CalcFFBatch(dim, left, index, right, metricUpper, &distArray)
|
||||
waitGroup.Done()
|
||||
}
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
waitGroup.Add(1)
|
||||
go CalcWorker(i)
|
||||
}
|
||||
waitGroup.Wait()
|
||||
|
||||
return distArray, nil
|
||||
}
|
218
pkg/util/distance/calc_distance_test.go
Normal file
218
pkg/util/distance/calc_distance_test.go
Normal file
@ -0,0 +1,218 @@
|
||||
package distance
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const PRECISION = 1e-5
|
||||
|
||||
func TestValidateMetricType(t *testing.T) {
|
||||
invalidMetric := []string{"", "aaa"}
|
||||
for _, str := range invalidMetric {
|
||||
_, err := ValidateMetricType(str)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
validMetric := []string{"L2", "ip", "COSINE"}
|
||||
for _, str := range validMetric {
|
||||
metric, err := ValidateMetricType(str)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, metric == L2 || metric == IP || metric == COSINE)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFloatArrayLength(t *testing.T) {
|
||||
err := ValidateFloatArrayLength(3, 12)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ValidateFloatArrayLength(5, 11)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func CreateFloatArray(n, dim int64) []float32 {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
num := n * dim
|
||||
array := make([]float32, num)
|
||||
for i := int64(0); i < num; i++ {
|
||||
array[i] = rand.Float32()
|
||||
}
|
||||
|
||||
return array
|
||||
}
|
||||
|
||||
func DistanceL2(left, right []float32) float32 {
|
||||
if len(left) != len(right) {
|
||||
panic("array dimension not equal")
|
||||
}
|
||||
var sum float32
|
||||
for i := 0; i < len(left); i++ {
|
||||
gap := left[i] - right[i]
|
||||
sum += gap * gap
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
func DistanceIP(left, right []float32) float32 {
|
||||
if len(left) != len(right) {
|
||||
panic("array dimension not equal")
|
||||
}
|
||||
var sum float32
|
||||
for i := 0; i < len(left); i++ {
|
||||
sum += left[i] * right[i]
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
func DistanceCosine(left, right []float32) float32 {
|
||||
if len(left) != len(right) {
|
||||
panic("array dimension not equal")
|
||||
}
|
||||
return DistanceIP(left, right) / float32(math.Sqrt(float64(DistanceIP(left, left))*float64(DistanceIP(right, right))))
|
||||
}
|
||||
|
||||
func Test_CalcL2(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 1
|
||||
var rightNum int64 = 1
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
sum := DistanceL2(left, right)
|
||||
|
||||
distance := L2Impl(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
distance = L2ImplPure(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
|
||||
left = []float32{0, 1, 2}
|
||||
right = []float32{1, 2, 3}
|
||||
expected := float32(3)
|
||||
distance = L2Impl(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
|
||||
distance = L2ImplPure(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
}
|
||||
|
||||
func Test_CalcIP(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 1
|
||||
var rightNum int64 = 1
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
sum := DistanceIP(left, right)
|
||||
|
||||
distance := IPImpl(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
distance = IPImplPure(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
|
||||
left = []float32{0, 1, 2}
|
||||
right = []float32{1, 2, 3}
|
||||
expected := float32(8)
|
||||
distance = IPImpl(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
distance = IPImplPure(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
}
|
||||
|
||||
func Test_CalcCosine(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 1
|
||||
var rightNum int64 = 1
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
sum := DistanceCosine(left, right)
|
||||
|
||||
distance := CosineImpl(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
distance = CosineImplPure(left, right)
|
||||
assert.InEpsilon(t, sum, distance, PRECISION)
|
||||
|
||||
left = []float32{0, 0, 10}
|
||||
right = []float32{6, 0, 8}
|
||||
expected := float32(0.8)
|
||||
distance = CosineImpl(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
distance = CosineImplPure(left, right)
|
||||
assert.Equal(t, expected, distance)
|
||||
}
|
||||
|
||||
func Test_CalcFloatDistance(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 10
|
||||
var rightNum int64 = 5
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
// Verify illegal cases
|
||||
_, err := CalcFloatDistance(dim, left, right, "HAMMIN")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(3, left, right, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(dim, left, right, "HAMMIN")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(0, left, right, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
distances, err := CalcFloatDistance(dim, left, right, "L2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the L2 distance algorithm is correct
|
||||
invalid := CreateFloatArray(rightNum, 10)
|
||||
_, err = CalcFloatDistance(dim, left, invalid, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
for j := int64(0); j < rightNum; j++ {
|
||||
v1 := left[i*dim : (i+1)*dim]
|
||||
v2 := right[j*dim : (j+1)*dim]
|
||||
sum := DistanceL2(v1, v2)
|
||||
assert.InEpsilon(t, sum, distances[i*rightNum+j], PRECISION)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the IP distance algorithm is correct
|
||||
distances, err = CalcFloatDistance(dim, left, right, "IP")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
for j := int64(0); j < rightNum; j++ {
|
||||
v1 := left[i*dim : (i+1)*dim]
|
||||
v2 := right[j*dim : (j+1)*dim]
|
||||
sum := DistanceIP(v1, v2)
|
||||
assert.InEpsilon(t, sum, distances[i*rightNum+j], PRECISION)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the COSINE distance algorithm is correct
|
||||
distances, err = CalcFloatDistance(dim, left, right, "COSINE")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
for j := int64(0); j < rightNum; j++ {
|
||||
v1 := left[i*dim : (i+1)*dim]
|
||||
v2 := right[j*dim : (j+1)*dim]
|
||||
sum := DistanceCosine(v1, v2)
|
||||
assert.InEpsilon(t, sum, distances[i*rightNum+j], PRECISION)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user