Refactor reduce using cgo in query node

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2020-11-19 14:13:39 +08:00 committed by yefu.chen
parent 16c96fa170
commit cf11212932
13 changed files with 205 additions and 95 deletions

View File

@ -17,10 +17,10 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"]
ENV DEBIAN_FRONTEND noninteractive ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 && \ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 clang-format-10 && \
wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \ wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \
apt-get update && apt-get install -y --no-install-recommends \ apt-get update && apt-get install -y --no-install-recommends \
g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-10 clang-tidy-10 lcov && \ g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-7 clang-tidy-7 lcov && \
apt-get remove --purge -y && \ apt-get remove --purge -y && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
@ -28,7 +28,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce
RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.9.tar.gz && \ RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.9.tar.gz && \
tar zxvf v0.3.9.tar.gz && cd OpenBLAS-0.3.9 && \ tar zxvf v0.3.9.tar.gz && cd OpenBLAS-0.3.9 && \
make TARGET=CORE2 DYNAMIC_ARCH=1 DYNAMIC_OLDER=1 USE_THREAD=0 USE_OPENMP=0 FC=gfortran CC=gcc COMMON_OPT="-O3 -g -fPIC" FCOMMON_OPT="-O3 -g -fPIC -frecursive" NMAX="NUM_THREADS=128" LIBPREFIX="libopenblas" LAPACKE="NO_LAPACKE=1" INTERFACE64=0 NO_STATIC=1 && \ make TARGET=CORE2 DYNAMIC_ARCH=1 DYNAMIC_OLDER=1 USE_THREAD=0 USE_OPENMP=0 FC=gfortran CC=gcc COMMON_OPT="-O3 -g -fPIC" FCOMMON_OPT="-O3 -g -fPIC -frecursive" NMAX="NUM_THREADS=128" LIBPREFIX="libopenblas" LAPACKE="NO_LAPACKE=1" INTERFACE64=0 NO_STATIC=1 && \
make PREFIX=/usr NO_STATIC=1 install && \ make PREFIX=/usr install && \
cd .. && rm -rf OpenBLAS-0.3.9 && rm v0.3.9.tar.gz cd .. && rm -rf OpenBLAS-0.3.9 && rm v0.3.9.tar.gz
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib" ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib"

View File

@ -10,7 +10,8 @@ set(SEGCORE_FILES
IndexingEntry.cpp IndexingEntry.cpp
InsertRecord.cpp InsertRecord.cpp
Reduce.cpp Reduce.cpp
plan_c.cpp) plan_c.cpp
reduce_c.cpp)
add_library(milvus_segcore SHARED add_library(milvus_segcore SHARED
${SEGCORE_FILES} ${SEGCORE_FILES}
) )

View File

@ -1,8 +1,11 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "Reduce.h"
namespace milvus::segcore { namespace milvus::segcore {
void Status
merge_into(int64_t queries, merge_into(int64_t queries,
int64_t topk, int64_t topk,
float* distances, float* distances,
@ -37,5 +40,6 @@ merge_into(int64_t queries,
std::copy_n(buf_dis.data(), topk, src2_dis); std::copy_n(buf_dis.data(), topk, src2_dis);
std::copy_n(buf_uids.data(), topk, src2_uids); std::copy_n(buf_uids.data(), topk, src2_uids);
} }
return Status::OK();
} }
} // namespace milvus::segcore } // namespace milvus::segcore

View File

@ -2,8 +2,11 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "utils/Status.h"
namespace milvus::segcore { namespace milvus::segcore {
void Status
merge_into(int64_t num_queries, merge_into(int64_t num_queries,
int64_t topk, int64_t topk,
float* distances, float* distances,

View File

@ -20,8 +20,8 @@ ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_
} }
long int long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup) { GetNumOfQueries(CPlaceholderGroup placeholder_group) {
auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup); auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholder_group);
return res; return res;
} }
@ -41,8 +41,8 @@ DeletePlan(CPlan cPlan) {
} }
void void
DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) { DeletePlaceholderGroup(CPlaceholderGroup cPlaceholder_group) {
auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup; auto placeHolder_group = (milvus::query::PlaceholderGroup*)cPlaceholder_group;
delete placeHolderGroup; delete placeHolder_group;
std::cout << "delete placeholder" << std::endl; std::cout << "delete placeholder" << std::endl;
} }

View File

@ -15,7 +15,7 @@ CPlaceholderGroup
ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size); ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size);
long int long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup); GetNumOfQueries(CPlaceholderGroup placeholder_group);
long int long int
GetTopK(CPlan plan); GetTopK(CPlan plan);
@ -24,7 +24,7 @@ void
DeletePlan(CPlan plan); DeletePlan(CPlan plan);
void void
DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup); DeletePlaceholderGroup(CPlaceholderGroup placeholder_group);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -0,0 +1,9 @@
#include "reduce_c.h"
#include "Reduce.h"
int
MergeInto(
long int num_queries, long int topk, float* distances, long int* uids, float* new_distances, long int* new_uids) {
auto status = milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids);
return status.code();
}

View File

@ -0,0 +1,13 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
int
MergeInto(
long int num_queries, long int topk, float* distances, long int* uids, float* new_distances, long int* new_uids);
#ifdef __cplusplus
}
#endif

View File

@ -6,6 +6,7 @@
#include "segcore/collection_c.h" #include "segcore/collection_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "pb/service_msg.pb.h" #include "pb/service_msg.pb.h"
#include "segcore/reduce_c.h"
#include <chrono> #include <chrono>
namespace chrono = std::chrono; namespace chrono = std::chrono;
@ -510,4 +511,33 @@ TEST(CApiTest, GetRowCountTest) {
// auto segment = NewSegment(collection, 0); // auto segment = NewSegment(collection, 0);
// DeleteCollection(collection); // DeleteCollection(collection);
// DeleteSegment(segment); // DeleteSegment(segment);
//} //}
TEST(CApiTest, MergeInto) {
std::vector<int64_t> uids;
std::vector<float> distance;
std::vector<int64_t> new_uids;
std::vector<float> new_distance;
int64_t num_queries = 1;
int64_t topk = 2;
uids.push_back(1);
uids.push_back(2);
distance.push_back(5);
distance.push_back(1000);
new_uids.push_back(3);
new_uids.push_back(4);
new_distance.push_back(2);
new_distance.push_back(6);
auto res = MergeInto(num_queries, topk, distance.data(), uids.data(), new_distance.data(), new_uids.data());
ASSERT_EQ(res, 0);
ASSERT_EQ(uids[0], 3);
ASSERT_EQ(distance[0], 2);
ASSERT_EQ(uids[1], 1);
ASSERT_EQ(distance[1], 5);
}

View File

@ -6,7 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"sort" "math"
"sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -17,8 +18,11 @@ import (
) )
type searchService struct { type searchService struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc wait sync.WaitGroup
cancel context.CancelFunc
msgBuffer chan msgstream.TsMsg
unsolvedMsg []msgstream.TsMsg
replica *collectionReplica replica *collectionReplica
tSafeWatcher *tSafeWatcher tSafeWatcher *tSafeWatcher
@ -29,11 +33,6 @@ type searchService struct {
type ResultEntityIds []UniqueID type ResultEntityIds []UniqueID
type SearchResult struct {
ResultIds []UniqueID
ResultDistances []float32
}
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService { func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
receiveBufSize := Params.searchReceiveBufSize() receiveBufSize := Params.searchReceiveBufSize()
pulsarBufSize := Params.searchPulsarBufSize() pulsarBufSize := Params.searchPulsarBufSize()
@ -58,9 +57,13 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
var outputStream msgstream.MsgStream = searchResultStream var outputStream msgstream.MsgStream = searchResultStream
searchServiceCtx, searchServiceCancel := context.WithCancel(ctx) searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
msgBuffer := make(chan msgstream.TsMsg, receiveBufSize)
unsolvedMsg := make([]msgstream.TsMsg, 0)
return &searchService{ return &searchService{
ctx: searchServiceCtx, ctx: searchServiceCtx,
cancel: searchServiceCancel, cancel: searchServiceCancel,
msgBuffer: msgBuffer,
unsolvedMsg: unsolvedMsg,
replica: replica, replica: replica,
tSafeWatcher: newTSafeWatcher(), tSafeWatcher: newTSafeWatcher(),
@ -73,27 +76,10 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
func (ss *searchService) start() { func (ss *searchService) start() {
(*ss.searchMsgStream).Start() (*ss.searchMsgStream).Start()
(*ss.searchResultMsgStream).Start() (*ss.searchResultMsgStream).Start()
ss.wait.Add(2)
go func() { go ss.receiveSearchMsg()
for { go ss.startSearchService()
select { ss.wait.Wait()
case <-ss.ctx.Done():
return
default:
msgPack := (*ss.searchMsgStream).Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
// TODO: add serviceTime check
err := ss.search(msgPack.Msgs)
if err != nil {
fmt.Println("search Failed")
ss.publishFailedSearchResult()
}
fmt.Println("Do search done")
}
}
}()
} }
func (ss *searchService) close() { func (ss *searchService) close() {
@ -114,12 +100,68 @@ func (ss *searchService) waitNewTSafe() Timestamp {
return timestamp return timestamp
} }
func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { func (ss *searchService) receiveSearchMsg() {
defer ss.wait.Done()
type SearchResult struct { for {
ResultID int64 select {
ResultDistance float32 case <-ss.ctx.Done():
return
default:
msgPack := (*ss.searchMsgStream).Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
for i := range msgPack.Msgs {
ss.msgBuffer <- msgPack.Msgs[i]
//fmt.Println("receive a search msg")
}
}
} }
}
func (ss *searchService) startSearchService() {
defer ss.wait.Done()
for {
select {
case <-ss.ctx.Done():
return
default:
serviceTimestamp := (*(*ss.replica).getTSafe()).get()
searchMsg := make([]msgstream.TsMsg, 0)
tempMsg := make([]msgstream.TsMsg, 0)
tempMsg = append(tempMsg, ss.unsolvedMsg...)
ss.unsolvedMsg = ss.unsolvedMsg[:0]
for _, msg := range tempMsg {
if msg.BeginTs() > serviceTimestamp {
searchMsg = append(searchMsg, msg)
continue
}
ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
}
msgBufferLength := len(ss.msgBuffer)
for i := 0; i < msgBufferLength; i++ {
msg := <-ss.msgBuffer
if msg.BeginTs() > serviceTimestamp {
searchMsg = append(searchMsg, msg)
continue
}
ss.unsolvedMsg = append(ss.unsolvedMsg, msg)
}
if len(searchMsg) <= 0 {
continue
}
err := ss.search(searchMsg)
if err != nil {
fmt.Println("search Failed")
ss.publishFailedSearchResult()
}
fmt.Println("Do search done")
}
}
}
func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
// TODO:: cache map[dsl]plan // TODO:: cache map[dsl]plan
// TODO: reBatched search requests // TODO: reBatched search requests
for _, msg := range searchMessages { for _, msg := range searchMessages {
@ -129,8 +171,6 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
} }
searchTimestamp := searchMsg.Timestamp searchTimestamp := searchMsg.Timestamp
// TODO:: add serviceable time
var queryBlob = searchMsg.Query.Value var queryBlob = searchMsg.Query.Value
query := servicepb.Query{} query := servicepb.Query{}
err := proto.Unmarshal(queryBlob, &query) err := proto.Unmarshal(queryBlob, &query)
@ -162,9 +202,11 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
for _, pg := range placeholderGroups { for _, pg := range placeholderGroups {
numQueries += pg.GetNumOfQuery() numQueries += pg.GetNumOfQuery()
} }
var searchResults = make([][]SearchResult, numQueries)
for i := 0; i < int(numQueries); i++ { resultIds := make([]IntPrimaryKey, topK*numQueries)
searchResults[i] = make([]SearchResult, 0) resultDistances := make([]float32, topK*numQueries)
for i := range resultDistances {
resultDistances[i] = math.MaxFloat32
} }
// 3. Do search in all segments // 3. Do search in all segments
@ -174,42 +216,27 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error {
return err return err
} }
for _, segment := range partition.segments { for _, segment := range partition.segments {
res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) err := segment.segmentSearch(plan,
placeholderGroups,
[]Timestamp{searchTimestamp},
resultIds,
resultDistances,
numQueries,
topK)
if err != nil { if err != nil {
return err return err
} }
for i := 0; int64(i) < numQueries; i++ {
for j := int64(i) * topK; j < int64(i+1)*topK; j++ {
searchResults[i] = append(searchResults[i], SearchResult{
ResultID: res.ResultIds[j],
ResultDistance: res.ResultDistances[j],
})
}
}
}
}
// 4. Reduce results
// TODO::reduce in c++ merge_into func
for _, temp := range searchResults {
sort.Slice(temp, func(i, j int) bool {
return temp[i].ResultDistance < temp[j].ResultDistance
})
}
for i, tmp := range searchResults {
if int64(len(tmp)) > topK {
searchResults[i] = searchResults[i][:topK]
} }
} }
// 4. return results
hits := make([]*servicepb.Hits, 0) hits := make([]*servicepb.Hits, 0)
for _, value := range searchResults { for i := int64(0); i < numQueries; i++ {
hit := servicepb.Hits{} hit := servicepb.Hits{}
score := servicepb.Score{} score := servicepb.Score{}
for j := 0; int64(j) < topK; j++ { for j := i * topK; j < (i+1)*topK; j++ {
hit.IDs = append(hit.IDs, value[j].ResultID) hit.IDs = append(hit.IDs, resultIds[j])
score.Values = append(score.Values, value[j].ResultDistance) score.Values = append(score.Values, resultDistances[j])
} }
hit.Scores = append(hit.Scores, &score) hit.Scores = append(hit.Scores, &score)
hits = append(hits, &hit) hits = append(hits, &hit)

View File

@ -175,8 +175,9 @@ func TestSearch_Search(t *testing.T) {
searchStream.SetPulsarClient(pulsarURL) searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels) searchStream.CreatePulsarProducers(searchProducerChannels)
var vecSearch = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17}
var searchRawData []byte var searchRawData []byte
for _, ele := range vec { for _, ele := range vecSearch {
buf := make([]byte, 4) buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
searchRawData = append(searchRawData, buf...) searchRawData = append(searchRawData, buf...)

View File

@ -9,6 +9,7 @@ package reader
#include "collection_c.h" #include "collection_c.h"
#include "segment_c.h" #include "segment_c.h"
#include "plan_c.h" #include "plan_c.h"
#include "reduce_c.h"
*/ */
import "C" import "C"
@ -178,14 +179,24 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps
return nil return nil
} }
func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) { func (s *Segment) segmentSearch(plan *Plan,
placeHolderGroups []*PlaceholderGroup,
timestamp []Timestamp,
resultIds []IntPrimaryKey,
resultDistances []float32,
numQueries int64,
topK int64) error {
/* /*
void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids, void* Search(void* plan,
float* result_distances) void* placeholder_groups,
uint64_t* timestamps,
int num_groups,
long int* result_ids,
float* result_distances);
*/ */
resultIds := make([]IntPrimaryKey, topK*numQueries) newResultIds := make([]IntPrimaryKey, topK*numQueries)
resultDistances := make([]float32, topK*numQueries) NewResultDistances := make([]float32, topK*numQueries)
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeHolderGroups { for _, pg := range placeHolderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
@ -194,16 +205,22 @@ func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGrou
var cTimestamp = (*C.ulong)(&timestamp[0]) var cTimestamp = (*C.ulong)(&timestamp[0])
var cResultIds = (*C.long)(&resultIds[0]) var cResultIds = (*C.long)(&resultIds[0])
var cResultDistances = (*C.float)(&resultDistances[0]) var cResultDistances = (*C.float)(&resultDistances[0])
var cNewResultIds = (*C.long)(&newResultIds[0])
var cNewResultDistances = (*C.float)(&NewResultDistances[0])
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroups = C.int(len(placeHolderGroups)) var cNumGroups = C.int(len(placeHolderGroups))
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances) var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cNewResultIds, cNewResultDistances)
if status != 0 { if status != 0 {
return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status))) return errors.New("search failed, error code = " + strconv.Itoa(int(status)))
} }
//fmt.Println("search Result---- Ids =", resultIds, ", Distances =", resultDistances) cNumQueries := C.long(numQueries)
cTopK := C.long(topK)
return &SearchResult{ResultIds: resultIds, ResultDistances: resultDistances}, nil // reduce search result
status = C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds)
if status != 0 {
return errors.New("merge search result failed, error code = " + strconv.Itoa(int(status)))
}
return nil
} }

View File

@ -661,8 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) {
for _, pg := range placeholderGroups { for _, pg := range placeholderGroups {
numQueries += pg.GetNumOfQuery() numQueries += pg.GetNumOfQuery()
} }
resultIds := make([]IntPrimaryKey, topK*numQueries)
resultDistances := make([]float32, topK*numQueries)
for i := range resultDistances {
resultDistances[i] = math.MaxFloat32
}
_, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, resultIds, resultDistances, numQueries, topK)
assert.NoError(t, err) assert.NoError(t, err)
} }