mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-06 05:49:24 +08:00
63ee525bf3
Signed-off-by: shaoyue.chen <shaoyue.chen@zilliz.com>
281 lines
7.3 KiB
Go
281 lines
7.3 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package distance
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Names of the metric types known to this package.
const (
	// L2 is the name of the Euclidean (L2) distance metric.
	L2 = "L2"
	// IP is the name of the inner product metric.
	IP = "IP"
	// HAMMING is the name of the hamming distance metric.
	HAMMING = "HAMMING"
	// TANIMOTO is the name of the tanimoto metric.
	TANIMOTO = "TANIMOTO"
	// JACCARD is the name of the jaccard metric.
	JACCARD = "JACCARD"
	// SUPERSTRUCTURE is the name of the superstructure metric.
	SUPERSTRUCTURE = "SUPERSTRUCTURE"
	// SUBSTRUCTURE is the name of the substructure metric.
	SUBSTRUCTURE = "SUBSTRUCTURE"
)
|
|
|
|
// ValidateMetricType returns metric text or error
|
|
func ValidateMetricType(metric string) (string, error) {
|
|
if metric == "" {
|
|
err := errors.New("metric type is empty")
|
|
return "", err
|
|
}
|
|
|
|
m := strings.ToUpper(metric)
|
|
if m == L2 || m == IP || m == HAMMING || m == TANIMOTO {
|
|
return m, nil
|
|
}
|
|
|
|
err := errors.New("invalid metric type")
|
|
return metric, err
|
|
}
|
|
|
|
// ValidateFloatArrayLength checks that a flat float vector buffer holding
// length elements can be split into whole vectors of dim elements each.
//
// It returns an error when dim is not positive, when length is zero, or when
// length is not a multiple of dim; otherwise it returns nil.
func ValidateFloatArrayLength(dim int64, length int) error {
	// Guard the modulo below: a zero dim would panic with an integer division
	// by zero, and a negative dim is never a usable vector dimension.
	if dim <= 0 {
		return errors.New("invalid dimension")
	}

	if length == 0 || int64(length)%dim != 0 {
		return errors.New("invalid float vector length")
	}

	return nil
}
|
|
|
|
// CalcL2 returns the sum of squared element-wise differences between the
// lIndex-th vector of left and the rIndex-th vector of right. Both slices are
// read as consecutive vectors of dim elements each. No square root is applied
// to the sum.
func CalcL2(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
	lBase, rBase := lIndex*dim, rIndex*dim

	var acc float32
	for k := int64(0); k < dim; k++ {
		diff := left[lBase+k] - right[rBase+k]
		acc += diff * diff
	}

	return acc
}
|
|
|
|
// CalcIP returns the inner product of the lIndex-th vector of left and the
// rIndex-th vector of right. Both slices are read as consecutive vectors of
// dim elements each.
func CalcIP(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
	lBase := lIndex * dim
	rBase := rIndex * dim

	var dot float32
	for k := int64(0); k < dim; k++ {
		dot += left[lBase+k] * right[rBase+k]
	}

	return dot
}
|
|
|
|
// CalcFFBatch calculate the distance of @left & @right vectors in batch by given @metic, store result in @result
|
|
func CalcFFBatch(dim int64, left []float32, lIndex int64, right []float32, metric string, result *[]float32) {
|
|
rightNum := int64(len(right)) / dim
|
|
for i := int64(0); i < rightNum; i++ {
|
|
var distance float32 = -1.0
|
|
if metric == L2 {
|
|
distance = CalcL2(dim, left, lIndex, right, i)
|
|
} else if metric == IP {
|
|
distance = CalcIP(dim, left, lIndex, right, i)
|
|
}
|
|
(*result)[lIndex*rightNum+i] = distance
|
|
}
|
|
}
|
|
|
|
// CalcFloatDistance calculate float distance by given metric
|
|
// it will checks input, and calculate the distance concurrently
|
|
func CalcFloatDistance(dim int64, left, right []float32, metric string) ([]float32, error) {
|
|
if dim <= 0 {
|
|
err := errors.New("invalid dimension")
|
|
return nil, err
|
|
}
|
|
|
|
metricUpper := strings.ToUpper(metric)
|
|
if metricUpper != L2 && metricUpper != IP {
|
|
err := errors.New("invalid metric type")
|
|
return nil, err
|
|
}
|
|
|
|
err := ValidateFloatArrayLength(dim, len(left))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = ValidateFloatArrayLength(dim, len(right))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
leftNum := int64(len(left)) / dim
|
|
rightNum := int64(len(right)) / dim
|
|
|
|
distArray := make([]float32, leftNum*rightNum)
|
|
|
|
// Multi-threads to calculate distance. TODO: avoid too many go routines
|
|
var waitGroup sync.WaitGroup
|
|
CalcWorker := func(index int64) {
|
|
CalcFFBatch(dim, left, index, right, metricUpper, &distArray)
|
|
waitGroup.Done()
|
|
}
|
|
for i := int64(0); i < leftNum; i++ {
|
|
waitGroup.Add(1)
|
|
go CalcWorker(i)
|
|
}
|
|
waitGroup.Wait()
|
|
|
|
return distArray, nil
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SingleBitLen returns the number of bits one binary vector of dimension dim
// occupies: dim itself when it is a multiple of 8, otherwise dim + 8 - dim%8
// (for a positive dim, the next multiple of 8).
func SingleBitLen(dim int64) int64 {
	rem := dim % 8
	if rem != 0 {
		return dim + 8 - rem
	}

	return dim
}
|
|
|
|
// VectorCount counts bits by @dim & @length
|
|
func VectorCount(dim int64, length int) int64 {
|
|
singleBitLen := SingleBitLen(dim)
|
|
return int64(length*8) / singleBitLen
|
|
}
|
|
|
|
// ValidateBinaryArrayLength validates a binary array of @dim & @length
|
|
func ValidateBinaryArrayLength(dim int64, length int) error {
|
|
singleBitLen := SingleBitLen(dim)
|
|
totalBitLen := int64(length * 8)
|
|
if length == 0 || totalBitLen%singleBitLen != 0 {
|
|
err := errors.New("invalid binary vector length")
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CountOne returns the number of bits set to 1 in n.
// For 00000010 it returns 1.
// For 11111111 it returns 8.
func CountOne(n uint8) int32 {
	var ones int32
	// n & (n - 1) clears the lowest set bit, so the loop body runs once per
	// set bit.
	for ; n != 0; n &= n - 1 {
		ones++
	}
	return ones
}
|
|
|
|
// CalcHamming calculate HAMMING distance
|
|
func CalcHamming(dim int64, left []byte, lIndex int64, right []byte, rIndex int64) int32 {
|
|
singleBitLen := SingleBitLen(dim)
|
|
numBytes := singleBitLen / 8
|
|
lFrom := lIndex * numBytes
|
|
rFrom := rIndex * numBytes
|
|
|
|
var hamming int32
|
|
for i := int64(0); i < numBytes; i++ {
|
|
var xor = left[lFrom+i] ^ right[rFrom+i]
|
|
|
|
// The dimension "dim" may not be an integer multiple of 8
|
|
// For example:
|
|
// dim = 11, each vector has 2 uint8 value
|
|
// the second uint8, only need to calculate 3 bits, the other 5 bits will be set to 0
|
|
if i == numBytes-1 && numBytes*8 > dim {
|
|
offset := numBytes*8 - dim
|
|
xor = xor & (255 << offset)
|
|
}
|
|
|
|
hamming += CountOne(xor)
|
|
}
|
|
|
|
return hamming
|
|
}
|
|
|
|
// CalcHammingBatch calculate HAMMING distance in batch, results are in @result
|
|
func CalcHammingBatch(dim int64, left []byte, lIndex int64, right []byte, result *[]int32) {
|
|
rightNum := VectorCount(dim, len(right))
|
|
|
|
for i := int64(0); i < rightNum; i++ {
|
|
hamming := CalcHamming(dim, left, lIndex, right, i)
|
|
(*result)[lIndex*rightNum+i] = hamming
|
|
}
|
|
}
|
|
|
|
func CalcHammingDistance(dim int64, left, right []byte) ([]int32, error) {
|
|
if dim <= 0 {
|
|
err := errors.New("invalid dimension")
|
|
return nil, err
|
|
}
|
|
|
|
err := ValidateBinaryArrayLength(dim, len(left))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = ValidateBinaryArrayLength(dim, len(right))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
leftNum := VectorCount(dim, len(left))
|
|
rightNum := VectorCount(dim, len(right))
|
|
distArray := make([]int32, leftNum*rightNum)
|
|
|
|
// Multi-threads to calculate distance. TODO: avoid too many go routines
|
|
var waitGroup sync.WaitGroup
|
|
CalcWorker := func(index int64) {
|
|
CalcHammingBatch(dim, left, index, right, &distArray)
|
|
waitGroup.Done()
|
|
}
|
|
for i := int64(0); i < leftNum; i++ {
|
|
waitGroup.Add(1)
|
|
go CalcWorker(i)
|
|
}
|
|
waitGroup.Wait()
|
|
|
|
return distArray, nil
|
|
}
|
|
|
|
// CalcTanimotoCoefficient converts hamming distances of binary vectors of
// dimension dim into tanimoto coefficients.
//
// For each distance h, with equalBits = dim - h, the coefficient is
// equalBits / (2*dim - equalBits). The result has one value per input
// distance, in the same order.
//
// An error is returned when dim is not positive, when hamming is empty, or
// when any distance lies outside the range [0, dim].
func CalcTanimotoCoefficient(dim int64, hamming []int32) ([]float32, error) {
	if dim <= 0 || len(hamming) == 0 {
		return nil, errors.New("invalid input for tanimoto")
	}

	array := make([]float32, len(hamming))
	for i, h := range hamming {
		// Compare in int64: narrowing dim to int32 could wrap for very large
		// dimensions. A negative distance is as invalid as one above dim.
		if h < 0 || int64(h) > dim {
			return nil, errors.New("invalid hamming for tanimoto")
		}

		equalBits := float32(dim - int64(h))
		array[i] = equalBits / (float32(dim)*2 - equalBits)
	}

	return array, nil
}
|