enhance: add efficient distance computations in Go (#28657)

Related #28656
Add a more efficient calc_distance implementation on the Go side.

---------

Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
Gao 2023-11-28 18:20:26 +08:00 committed by GitHub
parent 606ec77b66
commit 5bd73d5503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 717 additions and 0 deletions

View File

@ -0,0 +1,97 @@
//go:build ignore
package main
import (
. "github.com/mmcloughlin/avo/build"
. "github.com/mmcloughlin/avo/operand"
. "github.com/mmcloughlin/avo/reg"
)
// unroll is the number of YMM accumulators, i.e. how many 8-float vectors the
// block loop handles per iteration. The reduction in main names acc[0]..acc[3]
// explicitly, so it relies on this being 4.
var unroll = 4
// inspired by the avo example https://github.com/mmcloughlin/avo
// avo is a tool to generate go assembly
//
// main emits the assembly for IP(x, y []float32) float32. The emitted routine
// has three phases: a vectorised block loop, a scalar tail loop for the
// remaining elements, and a reduction of all partial sums to one float32.
// The order of the calls below is the order of the emitted instructions.
func main() {
TEXT("IP", NOSPLIT, "func(x, y []float32) float32")
Doc("inner product between x and y")
// Base pointers of both slices. Only x's length is loaded, so the emitted
// code walks y for len(x) elements without consulting len(y).
x := Mem{Base: Load(Param("x").Base(), GP64())}
y := Mem{Base: Load(Param("y").Base(), GP64())}
n := Load(Param("x").Len(), GP64())
// One virtual YMM register per unrolled accumulator.
acc := make([]VecVirtual, unroll)
for i := 0; i < unroll; i++ {
acc[i] = YMM()
}
// Zero initialization
for i := 0; i < unroll; i++ {
VXORPS(acc[i], acc[i], acc[i])
}
// Loop over blocks and process them with vector instructions
// blockitems: float32 elements per block (8 per YMM register * unroll).
// blocksize: the same block measured in bytes (4 bytes per float32).
blockitems := 8 * unroll
blocksize := 4 * blockitems
Label("blockloop")
// Leave the block loop once fewer than blockitems elements remain.
CMPQ(n, U32(blockitems))
JL(LabelRef("tail"))
// Load x
xs := make([]VecVirtual, unroll)
for i := 0; i < unroll; i++ {
xs[i] = YMM()
}
for i := 0; i < unroll; i++ {
VMOVUPS(x.Offset(32*i), xs[i])
}
// The actual FMA
// acc[i] += xs[i] * y[block i], with y read directly from memory.
for i := 0; i < unroll; i++ {
VFMADD231PS(y.Offset(32*i), xs[i], acc[i])
}
// Advance both pointers by one block and count the block off.
ADDQ(U32(blocksize), x.Base)
ADDQ(U32(blocksize), y.Base)
SUBQ(U32(blockitems), n)
JMP(LabelRef("blockloop"))
// Process any trailing entries
Label("tail")
// tail accumulates the products of the leftover elements one at a time.
tail := XMM()
VXORPS(tail, tail, tail)
Label("tailloop")
CMPQ(n, U32(0))
JE(LabelRef("reduce"))
xt := XMM()
VMOVSS(x, xt)
VFMADD231SS(y, xt, tail)
// Step both pointers by one float32 (4 bytes).
ADDQ(U32(4), x.Base)
ADDQ(U32(4), y.Base)
DECQ(n)
JMP(LabelRef("tailloop"))
// Reduce the lanes to one.
Label("reduce")
// Manual reduction
// Pairwise add the four accumulators into acc[0].
VADDPS(acc[0], acc[1], acc[0])
VADDPS(acc[2], acc[3], acc[2])
VADDPS(acc[0], acc[2], acc[0])
// Fold the upper 128 bits of acc[0] onto its lower 128 bits, add the scalar
// tail, then two horizontal adds leave the total in the lowest lane.
result := acc[0].AsX()
top := XMM()
VEXTRACTF128(U8(1), acc[0], top)
VADDPS(result, top, result)
VADDPS(result, tail, result)
VHADDPS(result, result, result)
VHADDPS(result, result, result)
Store(result, ReturnIndex(0))
RET()
Generate()
}

View File

@ -0,0 +1,55 @@
// Code generated by command: go run ip.go -out ip.s -stubs ip_stub.go. DO NOT EDIT.
#include "textflag.h"
// func IP(x []float32, y []float32) float32
// Requires: AVX, FMA3, SSE
//
// IP returns the inner product of x and y. Only x_len is loaded, so y is read
// for len(x) elements and its own length is never consulted.
// Registers: AX = x pointer, CX = y pointer, DX = elements still to process,
// Y0-Y3 = vector accumulators, X4 = scalar accumulator for the tail.
TEXT ·IP(SB), NOSPLIT, $0-52
MOVQ x_base+0(FP), AX
MOVQ y_base+24(FP), CX
MOVQ x_len+8(FP), DX
// Clear the four 8-lane accumulators.
VXORPS Y0, Y0, Y0
VXORPS Y1, Y1, Y1
VXORPS Y2, Y2, Y2
VXORPS Y3, Y3, Y3
blockloop:
// Fewer than 32 (0x20) elements left: switch to the scalar tail loop.
CMPQ DX, $0x00000020
JL tail
// Load 32 floats of x as four 8-float vectors.
VMOVUPS (AX), Y4
VMOVUPS 32(AX), Y5
VMOVUPS 64(AX), Y6
VMOVUPS 96(AX), Y7
// Y0..Y3 += x * y, with y read directly from memory.
VFMADD231PS (CX), Y4, Y0
VFMADD231PS 32(CX), Y5, Y1
VFMADD231PS 64(CX), Y6, Y2
VFMADD231PS 96(CX), Y7, Y3
// Advance both pointers by 128 bytes (32 float32) and count the block off.
ADDQ $0x00000080, AX
ADDQ $0x00000080, CX
SUBQ $0x00000020, DX
JMP blockloop
tail:
// X4 accumulates the products of the leftover elements.
VXORPS X4, X4, X4
tailloop:
CMPQ DX, $0x00000000
JE reduce
// One element per iteration: X4 += x[i] * y[i].
VMOVSS (AX), X5
VFMADD231SS (CX), X5, X4
ADDQ $0x00000004, AX
ADDQ $0x00000004, CX
DECQ DX
JMP tailloop
reduce:
// Add the four accumulators together into Y0.
VADDPS Y0, Y1, Y0
VADDPS Y2, Y3, Y2
VADDPS Y0, Y2, Y0
// Fold the upper 128 bits of Y0 onto the lower 128 bits.
VEXTRACTF128 $0x01, Y0, X1
VADDPS X0, X1, X0
// Add the scalar tail; two horizontal adds then leave the total in lane 0.
VADDPS X0, X4, X0
VHADDPS X0, X0, X0
VHADDPS X0, X0, X0
MOVSS X0, ret+48(FP)
RET

View File

@ -0,0 +1,6 @@
// Code generated by command: go run ip.go -out ip.s -stubs ip_stub.go. DO NOT EDIT.
package asm
// IP returns the inner product between x and y.
// The body is the assembly routine in ip.s, which iterates over len(x)
// elements and never reads len(y); y must hold at least len(x) elements.
func IP(x []float32, y []float32) float32

105
pkg/util/distance/asm/l2.go Normal file
View File

@ -0,0 +1,105 @@
//go:build ignore
package main
import (
. "github.com/mmcloughlin/avo/build"
. "github.com/mmcloughlin/avo/operand"
. "github.com/mmcloughlin/avo/reg"
)
// unroll is the number of YMM accumulators, i.e. how many 8-float vectors the
// block loop handles per iteration. The reduction in main names acc[0]..acc[3]
// explicitly, so it relies on this being 4.
var unroll = 4
// inspired by the avo example https://github.com/mmcloughlin/avo
// avo is a tool to generate go assembly
//
// main emits the assembly for L2(x, y []float32) float32, the squared L2
// distance. The emitted routine has three phases: a vectorised block loop, a
// scalar tail loop for the remaining elements, and a reduction of all partial
// sums to one float32. The order of the calls below is the order of the
// emitted instructions.
func main() {
TEXT("L2", NOSPLIT, "func(x, y []float32) float32")
Doc("squared l2 between x and y")
// Base pointers of both slices. Only x's length is loaded, so the emitted
// code walks y for len(x) elements without consulting len(y).
x := Mem{Base: Load(Param("x").Base(), GP64())}
y := Mem{Base: Load(Param("y").Base(), GP64())}
n := Load(Param("x").Len(), GP64())
// acc holds the running sums of squares; diff holds x-y for the current block.
acc := make([]VecVirtual, unroll)
diff := make([]VecVirtual, unroll)
for i := 0; i < unroll; i++ {
acc[i] = YMM()
diff[i] = YMM()
}
// Zero both register sets.
// NOTE(review): diff[i] is written by VSUBPS below before it is ever read,
// so zeroing it looks redundant — confirm before removing and regenerating.
for i := 0; i < unroll; i++ {
VXORPS(acc[i], acc[i], acc[i])
VXORPS(diff[i], diff[i], diff[i])
}
// blockitems: float32 elements per block (8 per YMM register * unroll).
// blocksize: the same block measured in bytes (4 bytes per float32).
blockitems := 8 * unroll
blocksize := 4 * blockitems
Label("blockloop")
// Leave the block loop once fewer than blockitems elements remain.
CMPQ(n, U32(blockitems))
JL(LabelRef("tail"))
// Load x
xs := make([]VecVirtual, unroll)
for i := 0; i < unroll; i++ {
xs[i] = YMM()
}
for i := 0; i < unroll; i++ {
VMOVUPS(x.Offset(32*i), xs[i])
}
// diff[i] = xs[i] - y[block i], with y read directly from memory.
for i := 0; i < unroll; i++ {
VSUBPS(y.Offset(32*i), xs[i], diff[i])
}
// acc[i] += diff[i] * diff[i]
for i := 0; i < unroll; i++ {
VFMADD231PS(diff[i], diff[i], acc[i])
}
// Advance both pointers by one block and count the block off.
ADDQ(U32(blocksize), x.Base)
ADDQ(U32(blocksize), y.Base)
SUBQ(U32(blockitems), n)
JMP(LabelRef("blockloop"))
// Process any trailing entries
Label("tail")
// tail accumulates the squared differences of the leftover elements.
tail := XMM()
VXORPS(tail, tail, tail)
Label("tailloop")
CMPQ(n, U32(0))
JE(LabelRef("reduce"))
xt := XMM()
VMOVSS(x, xt)
difft := XMM()
VSUBSS(y, xt, difft)
VFMADD231SS(difft, difft, tail)
// Step both pointers by one float32 (4 bytes).
ADDQ(U32(4), x.Base)
ADDQ(U32(4), y.Base)
DECQ(n)
JMP(LabelRef("tailloop"))
// Reduce the lanes to one
Label("reduce")
// Manual reduction
// Pairwise add the four accumulators into acc[0].
VADDPS(acc[0], acc[1], acc[0])
VADDPS(acc[2], acc[3], acc[2])
VADDPS(acc[0], acc[2], acc[0])
// Fold the upper 128 bits of acc[0] onto its lower 128 bits, add the scalar
// tail, then two horizontal adds leave the total in the lowest lane.
result := acc[0].AsX()
top := XMM()
VEXTRACTF128(U8(1), acc[0], top)
VADDPS(result, top, result)
VADDPS(result, tail, result)
VHADDPS(result, result, result)
VHADDPS(result, result, result)
Store(result, ReturnIndex(0))
RET()
Generate()
}

View File

@ -0,0 +1,64 @@
// Code generated by command: go run l2.go -out l2.s -stubs l2_stub.go. DO NOT EDIT.
#include "textflag.h"
// func L2(x []float32, y []float32) float32
// Requires: AVX, FMA3, SSE
//
// L2 returns the squared L2 distance between x and y. Only x_len is loaded,
// so y is read for len(x) elements and its own length is never consulted.
// Registers: AX = x pointer, CX = y pointer, DX = elements still to process,
// Y0/Y2/Y4/Y6 = vector accumulators, Y1/Y3/Y5/Y7 = x and then x-y for the
// current block, X1 = scalar accumulator for the tail.
TEXT ·L2(SB), NOSPLIT, $0-52
MOVQ x_base+0(FP), AX
MOVQ y_base+24(FP), CX
MOVQ x_len+8(FP), DX
// Clear all eight registers. Only Y0/Y2/Y4/Y6 are accumulators; the odd
// registers are overwritten by the loads below before they are read.
VXORPS Y0, Y0, Y0
VXORPS Y1, Y1, Y1
VXORPS Y2, Y2, Y2
VXORPS Y3, Y3, Y3
VXORPS Y4, Y4, Y4
VXORPS Y5, Y5, Y5
VXORPS Y6, Y6, Y6
VXORPS Y7, Y7, Y7
blockloop:
// Fewer than 32 (0x20) elements left: switch to the scalar tail loop.
CMPQ DX, $0x00000020
JL tail
// Load 32 floats of x as four 8-float vectors.
VMOVUPS (AX), Y1
VMOVUPS 32(AX), Y3
VMOVUPS 64(AX), Y5
VMOVUPS 96(AX), Y7
// Replace each x vector with x - y, reading y directly from memory.
VSUBPS (CX), Y1, Y1
VSUBPS 32(CX), Y3, Y3
VSUBPS 64(CX), Y5, Y5
VSUBPS 96(CX), Y7, Y7
// Accumulate the squared differences.
VFMADD231PS Y1, Y1, Y0
VFMADD231PS Y3, Y3, Y2
VFMADD231PS Y5, Y5, Y4
VFMADD231PS Y7, Y7, Y6
// Advance both pointers by 128 bytes (32 float32) and count the block off.
ADDQ $0x00000080, AX
ADDQ $0x00000080, CX
SUBQ $0x00000020, DX
JMP blockloop
tail:
// X1 accumulates the squared differences of the leftover elements.
VXORPS X1, X1, X1
tailloop:
CMPQ DX, $0x00000000
JE reduce
// One element per iteration: X3 = x[i] - y[i]; X1 += X3 * X3.
VMOVSS (AX), X3
VSUBSS (CX), X3, X3
VFMADD231SS X3, X3, X1
ADDQ $0x00000004, AX
ADDQ $0x00000004, CX
DECQ DX
JMP tailloop
reduce:
// Add the four accumulators together into Y0.
VADDPS Y0, Y2, Y0
VADDPS Y4, Y6, Y4
VADDPS Y0, Y4, Y0
// Fold the upper 128 bits of Y0 onto the lower 128 bits.
VEXTRACTF128 $0x01, Y0, X2
VADDPS X0, X2, X0
// Add the scalar tail; two horizontal adds then leave the total in lane 0.
VADDPS X0, X1, X0
VHADDPS X0, X0, X0
VHADDPS X0, X0, X0
MOVSS X0, ret+48(FP)
RET

View File

@ -0,0 +1,6 @@
// Code generated by command: go run l2.go -out l2.s -stubs l2_stub.go. DO NOT EDIT.
package asm
// L2 returns the squared l2 distance between x and y (no square root is taken).
// The body is the assembly routine in l2.s, which iterates over len(x)
// elements and never reads len(y); y must hold at least len(x) elements.
func L2(x []float32, y []float32) float32

View File

@ -0,0 +1,166 @@
package distance
import (
	"math"
	"runtime"
	"strings"
	"sync"

	"github.com/cockroachdb/errors"
	"golang.org/x/sys/cpu"

	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/distance/asm"
)
/**
* Deleted in #25663 (Remove calc_distance).
* Partially added back because the clustering feature needs to calculate the distance between a search vector and the clustering centers.
*/
// Metric type names accepted by this package. Validation upper-cases the
// caller's input before comparing it with these values.
const (
// L2 selects the Euclidean metric. The kernels in this package return the
// squared distance; no square root is taken.
L2 = "L2"
// IP selects the inner product.
IP = "IP"
// COSINE selects the cosine metric, computed as the inner product divided by
// the product of the two vector norms (a similarity, not 1 - similarity).
COSINE = "COSINE"
)
// L2ImplPure returns the squared Euclidean distance between a and b using
// plain Go. It walks len(a) elements, so b must be at least as long as a;
// a shorter b panics with an index-out-of-range error.
func L2ImplPure(a []float32, b []float32) float32 {
	var total float32
	for idx, av := range a {
		delta := av - b[idx]
		total += delta * delta
	}
	return total
}
// IPImplPure returns the inner product of a and b using plain Go.
// It walks len(a) elements, so b must be at least as long as a; a shorter b
// panics with an index-out-of-range error.
func IPImplPure(a []float32, b []float32) float32 {
	var dot float32
	for idx := 0; idx < len(a); idx++ {
		dot += a[idx] * b[idx]
	}
	return dot
}
// CosineImplPure returns the cosine similarity of a and b using plain Go:
// dot(a, b) / sqrt(|a|^2 * |b|^2). The three sums are accumulated in float32
// and only the product of the squared norms is formed in float64.
// It walks len(a) elements, so b must be at least as long as a. A zero vector
// makes the denominator zero, so the result is not a finite number.
func CosineImplPure(a []float32, b []float32) float32 {
	var dot, sqNormA, sqNormB float32
	for idx, av := range a {
		bv := b[idx]
		dot += av * bv
		sqNormA += av * av
		sqNormB += bv * bv
	}
	denom := math.Sqrt(float64(sqNormA) * float64(sqNormB))
	return dot / float32(denom)
}
// L2Impl, IPImpl and CosineImpl are the distance kernels called by
// CalcFFBatch. They are assigned once in init, either to the assembly
// routines from the asm package or to the *ImplPure Go fallbacks.
var (
L2Impl func(a []float32, b []float32) float32
IPImpl func(a []float32, b []float32) float32
CosineImpl func(a []float32, b []float32) float32
)
// init picks the distance kernels once, when the package is loaded.
//
// The assembly routines asm.IP and asm.L2 are built from AVX and FMA3
// instructions (VFMADD231PS/VFMADD231SS and friends). They are therefore
// enabled only when the CPU reports both features; AVX2 on its own does not
// guarantee FMA3, and running the kernels without it would fault with an
// illegal instruction. Every other CPU gets the pure Go implementations.
func init() {
	if cpu.X86.HasAVX && cpu.X86.HasFMA {
		log.Info("Hook avx for go simd distance computation")
		IPImpl = asm.IP
		L2Impl = asm.L2
		// Cosine similarity from three inner products: dot / sqrt(|a|^2 * |b|^2).
		CosineImpl = func(a []float32, b []float32) float32 {
			return asm.IP(a, b) / float32(math.Sqrt(float64(asm.IP(a, a))*float64((asm.IP(b, b)))))
		}
		return
	}

	log.Info("Use pure go distance computation")
	IPImpl = IPImplPure
	L2Impl = L2ImplPure
	CosineImpl = CosineImplPure
}
// ValidateMetricType checks that metric names a supported metric, ignoring
// case. On success it returns the upper-cased name (L2, IP or COSINE). An
// empty string yields "" and an error; any other unknown name is returned
// unchanged together with an error.
func ValidateMetricType(metric string) (string, error) {
	if len(metric) == 0 {
		return "", errors.New("metric type is empty")
	}

	switch upper := strings.ToUpper(metric); upper {
	case L2, IP, COSINE:
		return upper, nil
	default:
		return metric, errors.New("invalid metric type")
	}
}
// ValidateFloatArrayLength reports whether a flat float vector array of the
// given length can be split into one or more vectors of dimension dim.
//
// It returns an error when dim is not positive (which would otherwise divide
// by zero or accept a meaningless negative dimension), when length is zero,
// or when length is not a multiple of dim.
func ValidateFloatArrayLength(dim int64, length int) error {
	if dim <= 0 {
		return errors.New("invalid dimension")
	}
	if length == 0 || int64(length)%dim != 0 {
		return errors.New("invalid float vector length")
	}
	return nil
}
// CalcFFBatch computes the distance between one left vector and every vector
// in right, and stores the results in *result.
//
// left and right are flat arrays of dim-sized vectors; lIndex selects the left
// vector. The distance to right vector i is written to
// (*result)[lIndex*rightNum+i], where rightNum = len(right)/dim. metric is
// compared as-is with L2, IP and COSINE; any other value stores -1.
func CalcFFBatch(dim int64, left []float32, lIndex int64, right []float32, metric string, result *[]float32) {
	rightNum := int64(len(right)) / dim
	lBegin := lIndex * dim
	for r := int64(0); r < rightNum; r++ {
		rBegin := r * dim
		dist := float32(-1.0)
		switch metric {
		case L2:
			dist = L2Impl(left[lBegin:lBegin+dim], right[rBegin:rBegin+dim])
		case IP:
			dist = IPImpl(left[lBegin:lBegin+dim], right[rBegin:rBegin+dim])
		case COSINE:
			dist = CosineImpl(left[lBegin:lBegin+dim], right[rBegin:rBegin+dim])
		}
		(*result)[lIndex*rightNum+r] = dist
	}
}
// CalcFloatDistance computes the distance between every vector in left and
// every vector in right for the given metric.
//
// left and right are flat arrays of dim-sized vectors. metric is matched
// case-insensitively against L2, IP and COSINE. The result has
// (len(left)/dim) * (len(right)/dim) entries; the distance between left
// vector i and right vector j is stored at index i*(len(right)/dim)+j.
//
// An error is returned when dim is not positive, when the metric is unknown,
// or when either array is empty or not a multiple of dim.
//
// The left vectors are split into contiguous ranges that are processed
// concurrently by at most runtime.GOMAXPROCS(0) goroutines, so the number of
// goroutines no longer grows with the number of left vectors.
func CalcFloatDistance(dim int64, left, right []float32, metric string) ([]float32, error) {
	if dim <= 0 {
		return nil, errors.New("invalid dimension")
	}

	metricUpper := strings.ToUpper(metric)
	if metricUpper != L2 && metricUpper != IP && metricUpper != COSINE {
		return nil, errors.New("invalid metric type")
	}

	if err := ValidateFloatArrayLength(dim, len(left)); err != nil {
		return nil, err
	}
	if err := ValidateFloatArrayLength(dim, len(right)); err != nil {
		return nil, err
	}

	leftNum := int64(len(left)) / dim
	rightNum := int64(len(right)) / dim
	distArray := make([]float32, leftNum*rightNum)

	// leftNum is at least 1 here because empty arrays were rejected above, so
	// both workers and chunk are at least 1.
	workers := int64(runtime.GOMAXPROCS(0))
	if workers > leftNum {
		workers = leftNum
	}
	chunk := (leftNum + workers - 1) / workers

	// Each goroutine owns a disjoint range of left rows and therefore a
	// disjoint range of distArray, so no further synchronisation is needed.
	var waitGroup sync.WaitGroup
	for begin := int64(0); begin < leftNum; begin += chunk {
		end := begin + chunk
		if end > leftNum {
			end = leftNum
		}
		waitGroup.Add(1)
		go func(begin, end int64) {
			defer waitGroup.Done()
			for row := begin; row < end; row++ {
				CalcFFBatch(dim, left, row, right, metricUpper, &distArray)
			}
		}(begin, end)
	}
	waitGroup.Wait()

	return distArray, nil
}

View File

@ -0,0 +1,218 @@
package distance
import (
"math"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// PRECISION is the relative error tolerance passed to assert.InEpsilon when
// comparing a computed distance with its reference value.
const PRECISION = 1e-5
// TestValidateMetricType checks that unknown or empty metric names are
// rejected and that known names are accepted regardless of case and returned
// as one of the metric constants.
func TestValidateMetricType(t *testing.T) {
	for _, bad := range []string{"", "aaa"} {
		_, err := ValidateMetricType(bad)
		assert.Error(t, err)
	}

	for _, good := range []string{"L2", "ip", "COSINE"} {
		got, err := ValidateMetricType(good)
		assert.NoError(t, err)
		known := got == L2 || got == IP || got == COSINE
		assert.True(t, known)
	}
}
func TestValidateFloatArrayLength(t *testing.T) {
err := ValidateFloatArrayLength(3, 12)
assert.NoError(t, err)
err = ValidateFloatArrayLength(5, 11)
assert.Error(t, err)
}
////////////////////////////////////////////////////////////////////////////////
// floatArrayRand is the random source shared by all CreateFloatArray calls.
// It is seeded once, so successive calls continue one sequence instead of
// restarting it. It is not safe for concurrent use.
var floatArrayRand = rand.New(rand.NewSource(time.Now().UnixNano()))

// CreateFloatArray returns n*dim random float32 values in [0, 1), i.e. n
// vectors of dimension dim laid out back to back.
//
// It draws from a package-level generator rather than reseeding the global
// one on every call: reseeding from the clock could hand two calls made within
// the same clock tick identical data, and rand.Seed is deprecated.
func CreateFloatArray(n, dim int64) []float32 {
	num := n * dim
	array := make([]float32, num)
	for i := range array {
		array[i] = floatArrayRand.Float32()
	}
	return array
}
// DistanceL2 is the reference implementation of the squared L2 distance used
// by the tests. It panics when the two vectors differ in length.
func DistanceL2(left, right []float32) float32 {
	if len(left) != len(right) {
		panic("array dimension not equal")
	}

	var total float32
	for idx, l := range left {
		delta := l - right[idx]
		total += delta * delta
	}
	return total
}
// DistanceIP is the reference implementation of the inner product used by the
// tests. It panics when the two vectors differ in length.
func DistanceIP(left, right []float32) float32 {
	if len(left) != len(right) {
		panic("array dimension not equal")
	}

	var dot float32
	for idx, l := range left {
		dot += l * right[idx]
	}
	return dot
}
// DistanceCosine is the reference implementation of the cosine similarity
// used by the tests, built from three DistanceIP calls:
// dot(left, right) / sqrt(|left|^2 * |right|^2).
// It panics when the two vectors differ in length.
func DistanceCosine(left, right []float32) float32 {
	if len(left) != len(right) {
		panic("array dimension not equal")
	}

	dot := DistanceIP(left, right)
	sqNormLeft := float64(DistanceIP(left, left))
	sqNormRight := float64(DistanceIP(right, right))
	return dot / float32(math.Sqrt(sqNormLeft*sqNormRight))
}
// Test_CalcL2 compares the selected kernel (L2Impl) and the pure Go kernel
// with the reference DistanceL2 on one random 128-dimensional pair, then
// checks an exact value on a small hand-computed pair.
func Test_CalcL2(t *testing.T) {
	const dim int64 = 128

	x := CreateFloatArray(1, dim)
	y := CreateFloatArray(1, dim)
	want := DistanceL2(x, y)
	assert.InEpsilon(t, want, L2Impl(x, y), PRECISION)
	assert.InEpsilon(t, want, L2ImplPure(x, y), PRECISION)

	// (0-1)^2 + (1-2)^2 + (2-3)^2 = 3
	x, y = []float32{0, 1, 2}, []float32{1, 2, 3}
	exact := float32(3)
	assert.Equal(t, exact, L2Impl(x, y))
	assert.Equal(t, exact, L2ImplPure(x, y))
}
// Test_CalcIP compares the selected kernel (IPImpl) and the pure Go kernel
// with the reference DistanceIP on one random 128-dimensional pair, then
// checks an exact value on a small hand-computed pair.
func Test_CalcIP(t *testing.T) {
	const dim int64 = 128

	x := CreateFloatArray(1, dim)
	y := CreateFloatArray(1, dim)
	want := DistanceIP(x, y)
	assert.InEpsilon(t, want, IPImpl(x, y), PRECISION)
	assert.InEpsilon(t, want, IPImplPure(x, y), PRECISION)

	// 0*1 + 1*2 + 2*3 = 8
	x, y = []float32{0, 1, 2}, []float32{1, 2, 3}
	exact := float32(8)
	assert.Equal(t, exact, IPImpl(x, y))
	assert.Equal(t, exact, IPImplPure(x, y))
}
// Test_CalcCosine compares the selected kernel (CosineImpl) and the pure Go
// kernel with the reference DistanceCosine on one random 128-dimensional
// pair, then checks an exact value on a small hand-computed pair.
func Test_CalcCosine(t *testing.T) {
	const dim int64 = 128

	x := CreateFloatArray(1, dim)
	y := CreateFloatArray(1, dim)
	want := DistanceCosine(x, y)
	assert.InEpsilon(t, want, CosineImpl(x, y), PRECISION)
	assert.InEpsilon(t, want, CosineImplPure(x, y), PRECISION)

	// dot = 80, |x| = 10, |y| = 10, so the similarity is 0.8.
	x, y = []float32{0, 0, 10}, []float32{6, 0, 8}
	exact := float32(0.8)
	assert.Equal(t, exact, CosineImpl(x, y))
	assert.Equal(t, exact, CosineImplPure(x, y))
}
// Test_CalcFloatDistance checks the argument validation of CalcFloatDistance
// and then compares every entry of its result, for each supported metric,
// with the matching reference implementation.
func Test_CalcFloatDistance(t *testing.T) {
	var dim int64 = 128
	var leftNum int64 = 10
	var rightNum int64 = 5

	left := CreateFloatArray(leftNum, dim)
	right := CreateFloatArray(rightNum, dim)

	// Verify illegal cases
	// Unknown metric.
	_, err := CalcFloatDistance(dim, left, right, "HAMMIN")
	assert.Error(t, err)
	// Empty metric.
	_, err = CalcFloatDistance(dim, left, right, "")
	assert.Error(t, err)
	// Array lengths are not a multiple of the dimension.
	_, err = CalcFloatDistance(3, left, right, "L2")
	assert.Error(t, err)
	// Non-positive dimension.
	_, err = CalcFloatDistance(0, left, right, "L2")
	assert.Error(t, err)
	// Right array built with a different dimension.
	invalid := CreateFloatArray(rightNum, 10)
	_, err = CalcFloatDistance(dim, left, invalid, "L2")
	assert.Error(t, err)

	// Verify each distance algorithm against its reference implementation.
	references := []struct {
		metric    string
		reference func(left, right []float32) float32
	}{
		{"L2", DistanceL2},
		{"IP", DistanceIP},
		{"COSINE", DistanceCosine},
	}
	for _, tc := range references {
		distances, err := CalcFloatDistance(dim, left, right, tc.metric)
		// Skip the comparison when the call failed; distances would be nil.
		if !assert.NoError(t, err) {
			continue
		}
		for i := int64(0); i < leftNum; i++ {
			for j := int64(0); j < rightNum; j++ {
				v1 := left[i*dim : (i+1)*dim]
				v2 := right[j*dim : (j+1)*dim]
				sum := tc.reference(v1, v2)
				assert.InEpsilon(t, sum, distances[i*rightNum+j], PRECISION)
			}
		}
	}
}