mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 12:29:36 +08:00
27b9a51938
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
231 lines
5.9 KiB
Go
231 lines
5.9 KiB
Go
package proxy
|
|
|
|
import (
|
|
"github.com/apache/pulsar-client-go/pulsar"
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
|
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
|
"log"
|
|
"sync"
|
|
)
|
|
|
|
type queryReq struct {
|
|
internalpb.SearchRequest
|
|
result []*internalpb.SearchResult
|
|
wg sync.WaitGroup
|
|
proxy *proxyServer
|
|
}
|
|
|
|
// BaseRequest interfaces
|
|
func (req *queryReq) Type() internalpb.ReqType {
|
|
return req.ReqType
|
|
}
|
|
|
|
func (req *queryReq) PreExecute() commonpb.Status {
|
|
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
|
|
}
|
|
|
|
func (req *queryReq) Execute() commonpb.Status {
|
|
req.proxy.reqSch.queryChan <- req
|
|
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
|
|
}
|
|
|
|
func (req *queryReq) PostExecute() commonpb.Status { // send into pulsar
|
|
req.wg.Add(1)
|
|
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
|
|
}
|
|
|
|
func (req *queryReq) WaitToFinish() commonpb.Status { // wait unitl send into pulsar
|
|
req.wg.Wait()
|
|
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
|
|
}
|
|
|
|
func (s *proxyServer) restartQueryRoutine(buf_size int) error {
|
|
s.reqSch.queryChan = make(chan *queryReq, buf_size)
|
|
pulsarClient, err := pulsar.NewClient(pulsar.ClientOptions{URL: s.pulsarAddr})
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
query, err := pulsarClient.CreateProducer(pulsar.ProducerOptions{Topic: s.queryTopic})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result, err := pulsarClient.Subscribe(pulsar.ConsumerOptions{
|
|
Topic: s.resultTopic,
|
|
SubscriptionName: s.resultGroup,
|
|
Type: pulsar.KeyShared,
|
|
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
resultMap := make(map[int64]*queryReq)
|
|
|
|
go func() {
|
|
defer result.Close()
|
|
defer query.Close()
|
|
defer pulsarClient.Close()
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return
|
|
case qm := <-s.reqSch.queryChan:
|
|
ts, err := s.getTimestamp(1)
|
|
if err != nil {
|
|
log.Printf("get time stamp failed")
|
|
break
|
|
}
|
|
qm.Timestamp = uint64(ts[0])
|
|
|
|
|
|
qb, err := proto.Marshal(qm)
|
|
if err != nil {
|
|
log.Printf("Marshal QueryReqMsg failed, error = %v", err)
|
|
continue
|
|
}
|
|
if _, err := query.Send(s.ctx, &pulsar.ProducerMessage{Payload: qb}); err != nil {
|
|
log.Printf("post into puslar failed, error = %v", err)
|
|
}
|
|
s.reqSch.qTimestampMux.Lock()
|
|
if s.reqSch.qTimestamp <= ts[0] {
|
|
s.reqSch.qTimestamp = ts[0]
|
|
} else {
|
|
log.Printf("there is some wrong with q_timestamp, it goes back, current = %d, previous = %d", ts[0], s.reqSch.qTimestamp)
|
|
}
|
|
s.reqSch.qTimestampMux.Unlock()
|
|
resultMap[qm.ReqId] = qm
|
|
//log.Printf("start search, query id = %d", qm.QueryId)
|
|
case cm, ok := <-result.Chan():
|
|
if !ok {
|
|
log.Printf("consumer of result topic has closed")
|
|
return
|
|
}
|
|
var rm internalpb.SearchResult
|
|
if err := proto.Unmarshal(cm.Message.Payload(), &rm); err != nil {
|
|
log.Printf("Unmarshal QueryReqMsg failed, error = %v", err)
|
|
break
|
|
}
|
|
if rm.ProxyId != s.proxyId {
|
|
break
|
|
}
|
|
qm, ok := resultMap[rm.ReqId]
|
|
if !ok {
|
|
log.Printf("unknown query id = %d", rm.ReqId)
|
|
break
|
|
}
|
|
qm.result = append(qm.result, &rm)
|
|
if len(qm.result) == s.numReaderNode {
|
|
qm.wg.Done()
|
|
delete(resultMap, rm.ReqId)
|
|
}
|
|
result.AckID(cm.ID())
|
|
}
|
|
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
//func (s *proxyServer) reduceResult(query *queryReq) *servicepb.QueryResult {
|
|
//}
|
|
|
|
func (s *proxyServer) reduceResults(query *queryReq) *servicepb.QueryResult {
|
|
|
|
var results []*internalpb.SearchResult
|
|
var status commonpb.Status
|
|
status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
|
|
for _, r := range query.result {
|
|
status = *r.Status
|
|
if status.ErrorCode == commonpb.ErrorCode_SUCCESS {
|
|
results = append(results, r)
|
|
}else{
|
|
break
|
|
}
|
|
}
|
|
if len(results) != s.numReaderNode{
|
|
status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
|
|
}
|
|
if status.ErrorCode != commonpb.ErrorCode_SUCCESS{
|
|
result:= servicepb.QueryResult{
|
|
Status: &status,
|
|
}
|
|
return &result
|
|
}
|
|
|
|
if s.numReaderNode == 1 {
|
|
result:= servicepb.QueryResult{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_SUCCESS,
|
|
},
|
|
Hits: results[0].Hits,
|
|
}
|
|
return &result
|
|
}
|
|
|
|
//var entities []*struct {
|
|
// Idx int64
|
|
// Score float32
|
|
// Hit *servicepb.Hits
|
|
//}
|
|
//var rows int
|
|
//
|
|
//result_err := func(msg string) *pb.QueryResult {
|
|
// return &pb.QueryResult{
|
|
// Status: &pb.Status{
|
|
// ErrorCode: pb.ErrorCode_UNEXPECTED_ERROR,
|
|
// Reason: msg,
|
|
// },
|
|
// }
|
|
//}
|
|
|
|
//for _, r := range results {
|
|
// for i := 0; i < len(r.Hits); i++ {
|
|
// entity := struct {
|
|
// Ids int64
|
|
// ValidRow bool
|
|
// RowsData *pb.RowData
|
|
// Scores float32
|
|
// Distances float32
|
|
// }{
|
|
// Ids: r.Entities.Ids[i],
|
|
// ValidRow: r.Entities.ValidRow[i],
|
|
// RowsData: r.Entities.RowsData[i],
|
|
// Scores: r.Scores[i],
|
|
// Distances: r.Distances[i],
|
|
// }
|
|
// entities = append(entities, &entity)
|
|
// }
|
|
//}
|
|
//sort.Slice(entities, func(i, j int) bool {
|
|
// if entities[i].ValidRow == true {
|
|
// if entities[j].ValidRow == false {
|
|
// return true
|
|
// }
|
|
// return entities[i].Scores > entities[j].Scores
|
|
// } else {
|
|
// return false
|
|
// }
|
|
//})
|
|
//rIds := make([]int64, 0, rows)
|
|
//rValidRow := make([]bool, 0, rows)
|
|
//rRowsData := make([]*pb.RowData, 0, rows)
|
|
//rScores := make([]float32, 0, rows)
|
|
//rDistances := make([]float32, 0, rows)
|
|
//for i := 0; i < rows; i++ {
|
|
// rIds = append(rIds, entities[i].Ids)
|
|
// rValidRow = append(rValidRow, entities[i].ValidRow)
|
|
// rRowsData = append(rRowsData, entities[i].RowsData)
|
|
// rScores = append(rScores, entities[i].Scores)
|
|
// rDistances = append(rDistances, entities[i].Distances)
|
|
//}
|
|
|
|
return &servicepb.QueryResult{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_SUCCESS,
|
|
},
|
|
}
|
|
}
|