Mirror of https://gitee.com/milvus-io/milvus.git, synced 2024-11-29 18:38:44 +08:00
Refactor reduce using cgo in query node
Signed-off-by: xige-16 <xi.ge@zilliz.com>
parent 16c96fa170
commit cf11212932
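
In summary (from the diff below): this commit moves top-k result reduction in the query node from Go into the C++ segcore layer, reached through cgo. merge_into in Reduce.cpp now returns a Status instead of void; a new C wrapper (reduce_c.h / reduce_c.cpp) exposes it as MergeInto returning an error code; segmentSearch fills caller-provided resultIds / resultDistances buffers and calls C.MergeInto to fold each segment's hits into them, replacing the old sort.Slice-based reduce in search_service.go; and the search service is split into receiveSearchMsg / startSearchService goroutines coordinated through a msgBuffer channel, an unsolvedMsg backlog, and a sync.WaitGroup.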
build/docker/env/cpu/ubuntu18.04/Dockerfile (vendored, 6 lines changed)
@@ -17,10 +17,10 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"]
 
 ENV DEBIAN_FRONTEND noninteractive
 
-RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 && \
+RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 clang-format-10 && \
     wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \
     apt-get update && apt-get install -y --no-install-recommends \
-    g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-10 clang-tidy-10 lcov && \
+    g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-7 clang-tidy-7 lcov && \
     apt-get remove --purge -y && \
     rm -rf /var/lib/apt/lists/*
 
@@ -28,7 +28,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce
 RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.9.tar.gz && \
     tar zxvf v0.3.9.tar.gz && cd OpenBLAS-0.3.9 && \
     make TARGET=CORE2 DYNAMIC_ARCH=1 DYNAMIC_OLDER=1 USE_THREAD=0 USE_OPENMP=0 FC=gfortran CC=gcc COMMON_OPT="-O3 -g -fPIC" FCOMMON_OPT="-O3 -g -fPIC -frecursive" NMAX="NUM_THREADS=128" LIBPREFIX="libopenblas" LAPACKE="NO_LAPACKE=1" INTERFACE64=0 NO_STATIC=1 && \
-    make PREFIX=/usr NO_STATIC=1 install && \
+    make PREFIX=/usr install && \
     cd .. && rm -rf OpenBLAS-0.3.9 && rm v0.3.9.tar.gz
 ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib"
 
internal/core/src/segcore/CMakeLists.txt
@@ -10,7 +10,8 @@ set(SEGCORE_FILES
     IndexingEntry.cpp
     InsertRecord.cpp
     Reduce.cpp
-    plan_c.cpp)
+    plan_c.cpp
+    reduce_c.cpp)
 add_library(milvus_segcore SHARED
     ${SEGCORE_FILES}
 )
internal/core/src/segcore/Reduce.cpp
@@ -1,8 +1,11 @@
 #include <cstdint>
 #include <vector>
 #include <algorithm>
+
+#include "Reduce.h"
+
 namespace milvus::segcore {
-void
+Status
 merge_into(int64_t queries,
            int64_t topk,
            float* distances,
@@ -37,5 +40,6 @@ merge_into(int64_t queries,
         std::copy_n(buf_dis.data(), topk, src2_dis);
         std::copy_n(buf_uids.data(), topk, src2_uids);
     }
+    return Status::OK();
 }
 }  // namespace milvus::segcore
internal/core/src/segcore/Reduce.h
@@ -2,8 +2,11 @@
 #include <cstdint>
 #include <vector>
 #include <algorithm>
+
+#include "utils/Status.h"
+
 namespace milvus::segcore {
-void
+Status
 merge_into(int64_t num_queries,
            int64_t topk,
            float* distances,
internal/core/src/segcore/plan_c.cpp
@@ -20,8 +20,8 @@ ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_
 }
 
 long int
-GetNumOfQueries(CPlaceholderGroup placeholderGroup) {
-    auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup);
+GetNumOfQueries(CPlaceholderGroup placeholder_group) {
+    auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholder_group);
 
     return res;
 }
@@ -41,8 +41,8 @@ DeletePlan(CPlan cPlan) {
 }
 
 void
-DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) {
-    auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup;
-    delete placeHolderGroup;
+DeletePlaceholderGroup(CPlaceholderGroup cPlaceholder_group) {
+    auto placeHolder_group = (milvus::query::PlaceholderGroup*)cPlaceholder_group;
+    delete placeHolder_group;
     std::cout << "delete placeholder" << std::endl;
 }
internal/core/src/segcore/plan_c.h
@@ -15,7 +15,7 @@ CPlaceholderGroup
 ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size);
 
 long int
-GetNumOfQueries(CPlaceholderGroup placeholderGroup);
+GetNumOfQueries(CPlaceholderGroup placeholder_group);
 
 long int
 GetTopK(CPlan plan);
@@ -24,7 +24,7 @@ void
 DeletePlan(CPlan plan);
 
 void
-DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup);
+DeletePlaceholderGroup(CPlaceholderGroup placeholder_group);
 
 #ifdef __cplusplus
 }
internal/core/src/segcore/reduce_c.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include "reduce_c.h"
+#include "Reduce.h"
+
+int
+MergeInto(
+    long int num_queries, long int topk, float* distances, long int* uids, float* new_distances, long int* new_uids) {
+    auto status = milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids);
+    return status.code();
+}
internal/core/src/segcore/reduce_c.h (new file, 13 lines)
@@ -0,0 +1,13 @@
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdbool.h>
+
+int
+MergeInto(
+    long int num_queries, long int topk, float* distances, long int* uids, float* new_distances, long int* new_uids);
+
+#ifdef __cplusplus
+}
+#endif
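
For readers unfamiliar with cgo, below is a minimal, self-contained sketch of how Go reaches a C API shaped like MergeInto. This is not the milvus build — there the symbol comes from the milvus_segcore library built from Reduce.cpp above — so a stand-in C implementation with the same signature and in-place merge semantics is inlined in the cgo preamble; it also assumes a 64-bit C long, as segment.go below does:

    package main

    /*
    #include <stdlib.h>

    // Stand-in for MergeInto from reduce_c.h: per query, merge two ascending
    // top-k lists and keep the k smallest entries in the first pair of arrays.
    int MergeInto(long num_queries, long topk, float* distances, long* uids,
                  float* new_distances, long* new_uids) {
        float* buf_dis = (float*)malloc(topk * sizeof(float));
        long* buf_uid = (long*)malloc(topk * sizeof(long));
        if (buf_dis == NULL || buf_uid == NULL) {
            free(buf_dis);
            free(buf_uid);
            return 1;  // nonzero, like a failed Status code
        }
        for (long q = 0; q < num_queries; q++) {
            float* dis = distances + q * topk;
            long* uid = uids + q * topk;
            float* ndis = new_distances + q * topk;
            long* nuid = new_uids + q * topk;
            long i = 0, j = 0;
            for (long k = 0; k < topk; k++) {  // two-way sorted merge
                if (i < topk && (j >= topk || dis[i] <= ndis[j])) {
                    buf_dis[k] = dis[i];
                    buf_uid[k] = uid[i];
                    i++;
                } else {
                    buf_dis[k] = ndis[j];
                    buf_uid[k] = nuid[j];
                    j++;
                }
            }
            for (long k = 0; k < topk; k++) {
                dis[k] = buf_dis[k];
                uid[k] = buf_uid[k];
            }
        }
        free(buf_dis);
        free(buf_uid);
        return 0;  // Status::OK() maps to 0 in the real wrapper
    }
    */
    import "C"

    import "fmt"

    func main() {
        // Same inputs as TEST(CApiTest, MergeInto) further down.
        uids := []int64{1, 2}
        distances := []float32{5, 1000}
        newUids := []int64{3, 4}
        newDistances := []float32{2, 6}

        status := C.MergeInto(1, 2,
            (*C.float)(&distances[0]), (*C.long)(&uids[0]),
            (*C.float)(&newDistances[0]), (*C.long)(&newUids[0]))

        fmt.Println(status, uids, distances) // 0 [3 1] [2 5]
    }

The printed result matches the assertions of TEST(CApiTest, MergeInto) below: status 0, and the accumulated pair now holds the two smallest distances (2 and 5) with their ids (3 and 1).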
internal/core/unittest/test_c_api.cpp
@@ -6,6 +6,7 @@
 #include "segcore/collection_c.h"
 #include "segcore/segment_c.h"
 #include "pb/service_msg.pb.h"
+#include "segcore/reduce_c.h"
 
 #include <chrono>
 namespace chrono = std::chrono;
@@ -510,4 +511,33 @@ TEST(CApiTest, GetRowCountTest) {
 //  auto segment = NewSegment(collection, 0);
 //  DeleteCollection(collection);
 //  DeleteSegment(segment);
 //}
+
+TEST(CApiTest, MergeInto) {
+    std::vector<int64_t> uids;
+    std::vector<float> distance;
+
+    std::vector<int64_t> new_uids;
+    std::vector<float> new_distance;
+
+    int64_t num_queries = 1;
+    int64_t topk = 2;
+
+    uids.push_back(1);
+    uids.push_back(2);
+    distance.push_back(5);
+    distance.push_back(1000);
+
+    new_uids.push_back(3);
+    new_uids.push_back(4);
+    new_distance.push_back(2);
+    new_distance.push_back(6);
+
+    auto res = MergeInto(num_queries, topk, distance.data(), uids.data(), new_distance.data(), new_uids.data());
+
+    ASSERT_EQ(res, 0);
+    ASSERT_EQ(uids[0], 3);
+    ASSERT_EQ(distance[0], 2);
+    ASSERT_EQ(uids[1], 1);
+    ASSERT_EQ(distance[1], 5);
+}
internal/reader/search_service.go
@@ -6,7 +6,8 @@ import (
 	"errors"
 	"fmt"
 	"log"
-	"sort"
+	"math"
+	"sync"
 
 	"github.com/golang/protobuf/proto"
 
@@ -17,8 +18,11 @@ import (
 )
 
 type searchService struct {
 	ctx    context.Context
-	cancel context.CancelFunc
+	wait   sync.WaitGroup
+	cancel context.CancelFunc
+	msgBuffer   chan msgstream.TsMsg
+	unsolvedMsg []msgstream.TsMsg
 
 	replica      *collectionReplica
 	tSafeWatcher *tSafeWatcher
@@ -29,11 +33,6 @@ type searchService struct {
 
 type ResultEntityIds []UniqueID
 
-type SearchResult struct {
-	ResultIds       []UniqueID
-	ResultDistances []float32
-}
-
 func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
 	receiveBufSize := Params.searchReceiveBufSize()
 	pulsarBufSize := Params.searchPulsarBufSize()
@@ -58,9 +57,13 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
 	var outputStream msgstream.MsgStream = searchResultStream
 
 	searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
+	msgBuffer := make(chan msgstream.TsMsg, receiveBufSize)
+	unsolvedMsg := make([]msgstream.TsMsg, 0)
 	return &searchService{
 		ctx:    searchServiceCtx,
 		cancel: searchServiceCancel,
+		msgBuffer:   msgBuffer,
+		unsolvedMsg: unsolvedMsg,
 
 		replica:      replica,
 		tSafeWatcher: newTSafeWatcher(),
@@ -73,27 +76,10 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
 func (ss *searchService) start() {
 	(*ss.searchMsgStream).Start()
 	(*ss.searchResultMsgStream).Start()
-	go func() {
-		for {
-			select {
-			case <-ss.ctx.Done():
-				return
-			default:
-				msgPack := (*ss.searchMsgStream).Consume()
-				if msgPack == nil || len(msgPack.Msgs) <= 0 {
-					continue
-				}
-				// TODO: add serviceTime check
-				err := ss.search(msgPack.Msgs)
-				if err != nil {
-					fmt.Println("search Failed")
-					ss.publishFailedSearchResult()
-				}
-				fmt.Println("Do search done")
-			}
-		}
-	}()
+	ss.wait.Add(2)
+	go ss.receiveSearchMsg()
+	go ss.startSearchService()
+	ss.wait.Wait()
 }
 
 func (ss *searchService) close() {
@@ -114,12 +100,68 @@ func (ss *searchService) waitNewTSafe() Timestamp {
 	return timestamp
 }
 
-func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
-	type SearchResult struct {
-		ResultID       int64
-		ResultDistance float32
-	}
+func (ss *searchService) receiveSearchMsg() {
+	defer ss.wait.Done()
+	for {
+		select {
+		case <-ss.ctx.Done():
+			return
+		default:
+			msgPack := (*ss.searchMsgStream).Consume()
+			if msgPack == nil || len(msgPack.Msgs) <= 0 {
+				continue
+			}
+			for i := range msgPack.Msgs {
+				ss.msgBuffer <- msgPack.Msgs[i]
+				//fmt.Println("receive a search msg")
+			}
+		}
+	}
+}
+
+func (ss *searchService) startSearchService() {
+	defer ss.wait.Done()
+	for {
+		select {
+		case <-ss.ctx.Done():
+			return
+		default:
+			serviceTimestamp := (*(*ss.replica).getTSafe()).get()
+			searchMsg := make([]msgstream.TsMsg, 0)
+			tempMsg := make([]msgstream.TsMsg, 0)
+			tempMsg = append(tempMsg, ss.unsolvedMsg...)
+			ss.unsolvedMsg = ss.unsolvedMsg[:0]
+			for _, msg := range tempMsg {
+				if msg.BeginTs() > serviceTimestamp {
+					searchMsg = append(searchMsg, msg)
+					continue
+				}
+				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
+			}
+
+			msgBufferLength := len(ss.msgBuffer)
+			for i := 0; i < msgBufferLength; i++ {
+				msg := <-ss.msgBuffer
+				if msg.BeginTs() > serviceTimestamp {
+					searchMsg = append(searchMsg, msg)
+					continue
+				}
+				ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
+			}
+			if len(searchMsg) <= 0 {
+				continue
+			}
+			err := ss.search(searchMsg)
+			if err != nil {
+				fmt.Println("search Failed")
+				ss.publishFailedSearchResult()
+			}
+			fmt.Println("Do search done")
+		}
+	}
+}
+
+func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
 	// TODO:: cache map[dsl]plan
 	// TODO: reBatched search requests
 	for _, msg := range searchMessages {
@@ -129,8 +171,6 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
 		}
 
 		searchTimestamp := searchMsg.Timestamp
-
-		// TODO:: add serviceable time
 		var queryBlob = searchMsg.Query.Value
 		query := servicepb.Query{}
 		err := proto.Unmarshal(queryBlob, &query)
@@ -162,9 +202,11 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
 		for _, pg := range placeholderGroups {
 			numQueries += pg.GetNumOfQuery()
 		}
-		var searchResults = make([][]SearchResult, numQueries)
-		for i := 0; i < int(numQueries); i++ {
-			searchResults[i] = make([]SearchResult, 0)
+		resultIds := make([]IntPrimaryKey, topK*numQueries)
+		resultDistances := make([]float32, topK*numQueries)
+		for i := range resultDistances {
+			resultDistances[i] = math.MaxFloat32
 		}
 
 		// 3. Do search in all segments
@@ -174,42 +216,27 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
 			return err
 		}
 		for _, segment := range partition.segments {
-			res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
+			err := segment.segmentSearch(plan,
+				placeholderGroups,
+				[]Timestamp{searchTimestamp},
+				resultIds,
+				resultDistances,
+				numQueries,
+				topK)
 			if err != nil {
 				return err
 			}
-			for i := 0; int64(i) < numQueries; i++ {
-				for j := int64(i) * topK; j < int64(i+1)*topK; j++ {
-					searchResults[i] = append(searchResults[i], SearchResult{
-						ResultID:       res.ResultIds[j],
-						ResultDistance: res.ResultDistances[j],
-					})
-				}
-			}
-		}
-
-		// 4. Reduce results
-		// TODO::reduce in c++ merge_into func
-		for _, temp := range searchResults {
-			sort.Slice(temp, func(i, j int) bool {
-				return temp[i].ResultDistance < temp[j].ResultDistance
-			})
-		}
-
-		for i, tmp := range searchResults {
-			if int64(len(tmp)) > topK {
-				searchResults[i] = searchResults[i][:topK]
-			}
 		}
 
+		// 4. return results
 		hits := make([]*servicepb.Hits, 0)
-		for _, value := range searchResults {
+		for i := int64(0); i < numQueries; i++ {
 			hit := servicepb.Hits{}
 			score := servicepb.Score{}
-			for j := 0; int64(j) < topK; j++ {
-				hit.IDs = append(hit.IDs, value[j].ResultID)
-				score.Values = append(score.Values, value[j].ResultDistance)
+			for j := i * topK; j < (i+1)*topK; j++ {
+				hit.IDs = append(hit.IDs, resultIds[j])
+				score.Values = append(score.Values, resultDistances[j])
 			}
 			hit.Scores = append(hit.Scores, &score)
 			hits = append(hits, &hit)
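
One design point worth spelling out: search() can pre-fill resultDistances with math.MaxFloat32 and then just loop over segments because the sentinel acts as an identity element for the merge — real distances always sort in front of it, and each merge keeps the global top-k seen so far. Below is a pure-Go sketch of that accumulation, with a hypothetical mergeInto helper mirroring the C++ contract and values matching the C API test above:

    package main

    import (
        "fmt"
        "math"
    )

    // mergeInto folds src (sorted ascending) into dst (sorted ascending) in
    // place, keeping the topk smallest entries — the contract of the C++
    // merge_into in Reduce.cpp.
    func mergeInto(topk int, dstDis []float32, dstIDs []int64, srcDis []float32, srcIDs []int64) {
        bufDis := make([]float32, topk)
        bufIDs := make([]int64, topk)
        i, j := 0, 0
        for k := 0; k < topk; k++ {
            if i < topk && (j >= topk || dstDis[i] <= srcDis[j]) {
                bufDis[k], bufIDs[k] = dstDis[i], dstIDs[i]
                i++
            } else {
                bufDis[k], bufIDs[k] = srcDis[j], srcIDs[j]
                j++
            }
        }
        copy(dstDis, bufDis)
        copy(dstIDs, bufIDs)
    }

    func main() {
        const topk = 2
        // Accumulated buffers start at the MaxFloat32 sentinel, so the first
        // segment's real results always displace them.
        dstDis := []float32{math.MaxFloat32, math.MaxFloat32}
        dstIDs := []int64{-1, -1}

        // Hypothetical per-segment sorted top-k results.
        segs := []struct {
            dis []float32
            ids []int64
        }{
            {[]float32{5, 1000}, []int64{1, 2}},
            {[]float32{2, 6}, []int64{3, 4}},
        }
        for _, s := range segs {
            mergeInto(topk, dstDis, dstIDs, s.dis, s.ids)
        }
        fmt.Println(dstDis, dstIDs) // [2 5] [3 1]
    }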
internal/reader/search_service_test.go
@@ -175,8 +175,9 @@ func TestSearch_Search(t *testing.T) {
 	searchStream.SetPulsarClient(pulsarURL)
 	searchStream.CreatePulsarProducers(searchProducerChannels)
 
+	var vecSearch = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17}
 	var searchRawData []byte
-	for _, ele := range vec {
+	for _, ele := range vecSearch {
 		buf := make([]byte, 4)
 		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
 		searchRawData = append(searchRawData, buf...)
internal/reader/segment.go
@@ -9,6 +9,7 @@ package reader
 #include "collection_c.h"
 #include "segment_c.h"
 #include "plan_c.h"
+#include "reduce_c.h"
 
 */
 import "C"
@@ -178,14 +179,24 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps
 	return nil
 }
 
-func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) {
+func (s *Segment) segmentSearch(plan *Plan,
+	placeHolderGroups []*PlaceholderGroup,
+	timestamp []Timestamp,
+	resultIds []IntPrimaryKey,
+	resultDistances []float32,
+	numQueries int64,
+	topK int64) error {
 	/*
-		void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids,
-			float* result_distances)
+		void* Search(void* plan,
+			void* placeholder_groups,
+			uint64_t* timestamps,
+			int num_groups,
+			long int* result_ids,
+			float* result_distances);
 	*/
 
-	resultIds := make([]IntPrimaryKey, topK*numQueries)
-	resultDistances := make([]float32, topK*numQueries)
+	newResultIds := make([]IntPrimaryKey, topK*numQueries)
+	NewResultDistances := make([]float32, topK*numQueries)
 	cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
 	for _, pg := range placeHolderGroups {
 		cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
@@ -194,16 +205,22 @@ func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGrou
 	var cTimestamp = (*C.ulong)(&timestamp[0])
 	var cResultIds = (*C.long)(&resultIds[0])
 	var cResultDistances = (*C.float)(&resultDistances[0])
+	var cNewResultIds = (*C.long)(&newResultIds[0])
+	var cNewResultDistances = (*C.float)(&NewResultDistances[0])
 	var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
 	var cNumGroups = C.int(len(placeHolderGroups))
 
-	var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances)
+	var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cNewResultIds, cNewResultDistances)
 
 	if status != 0 {
-		return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status)))
+		return errors.New("search failed, error code = " + strconv.Itoa(int(status)))
 	}
 
-	//fmt.Println("search Result---- Ids =", resultIds, ", Distances =", resultDistances)
-	return &SearchResult{ResultIds: resultIds, ResultDistances: resultDistances}, nil
+	cNumQueries := C.long(numQueries)
+	cTopK := C.long(topK)
+	// reduce search result
+	status = C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds)
+	if status != 0 {
+		return errors.New("merge search result failed, error code = " + strconv.Itoa(int(status)))
+	}
+	return nil
 }
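
A note on the buffer layout in segmentSearch above: C.Search writes each segment's fresh hits into the scratch pair newResultIds / NewResultDistances, and C.MergeInto then folds them into the caller-owned running buffers resultIds / resultDistances, so the running top-k never leaves the caller's allocation. That is also why the test below seeds resultDistances with math.MaxFloat32 before the first call.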
internal/reader/segment_test.go
@@ -661,8 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) {
 	for _, pg := range placeholderGroups {
 		numQueries += pg.GetNumOfQuery()
 	}
+	resultIds := make([]IntPrimaryKey, topK*numQueries)
+	resultDistances := make([]float32, topK*numQueries)
+	for i := range resultDistances {
+		resultDistances[i] = math.MaxFloat32
+	}
 
-	_, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
+	err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, resultIds, resultDistances, numQueries, topK)
 	assert.NoError(t, err)
 }