milvus/internal/proxy/query_req.go
zhenshan.cao 27b9a51938 Add logic of allocate id
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
2020-10-30 16:27:58 +08:00

231 lines
5.9 KiB
Go

package proxy
import (
"github.com/apache/pulsar-client-go/pulsar"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
"log"
"sync"
)
type queryReq struct {
internalpb.SearchRequest
result []*internalpb.SearchResult
wg sync.WaitGroup
proxy *proxyServer
}
// BaseRequest interfaces
func (req *queryReq) Type() internalpb.ReqType {
return req.ReqType
}
func (req *queryReq) PreExecute() commonpb.Status {
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
}
func (req *queryReq) Execute() commonpb.Status {
req.proxy.reqSch.queryChan <- req
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
}
func (req *queryReq) PostExecute() commonpb.Status { // send into pulsar
req.wg.Add(1)
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
}
func (req *queryReq) WaitToFinish() commonpb.Status { // wait unitl send into pulsar
req.wg.Wait()
return commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}
}
func (s *proxyServer) restartQueryRoutine(buf_size int) error {
s.reqSch.queryChan = make(chan *queryReq, buf_size)
pulsarClient, err := pulsar.NewClient(pulsar.ClientOptions{URL: s.pulsarAddr})
if err != nil {
return nil
}
query, err := pulsarClient.CreateProducer(pulsar.ProducerOptions{Topic: s.queryTopic})
if err != nil {
return err
}
result, err := pulsarClient.Subscribe(pulsar.ConsumerOptions{
Topic: s.resultTopic,
SubscriptionName: s.resultGroup,
Type: pulsar.KeyShared,
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
})
if err != nil {
return err
}
resultMap := make(map[int64]*queryReq)
go func() {
defer result.Close()
defer query.Close()
defer pulsarClient.Close()
for {
select {
case <-s.ctx.Done():
return
case qm := <-s.reqSch.queryChan:
ts, err := s.getTimestamp(1)
if err != nil {
log.Printf("get time stamp failed")
break
}
qm.Timestamp = uint64(ts[0])
qb, err := proto.Marshal(qm)
if err != nil {
log.Printf("Marshal QueryReqMsg failed, error = %v", err)
continue
}
if _, err := query.Send(s.ctx, &pulsar.ProducerMessage{Payload: qb}); err != nil {
log.Printf("post into puslar failed, error = %v", err)
}
s.reqSch.qTimestampMux.Lock()
if s.reqSch.qTimestamp <= ts[0] {
s.reqSch.qTimestamp = ts[0]
} else {
log.Printf("there is some wrong with q_timestamp, it goes back, current = %d, previous = %d", ts[0], s.reqSch.qTimestamp)
}
s.reqSch.qTimestampMux.Unlock()
resultMap[qm.ReqId] = qm
//log.Printf("start search, query id = %d", qm.QueryId)
case cm, ok := <-result.Chan():
if !ok {
log.Printf("consumer of result topic has closed")
return
}
var rm internalpb.SearchResult
if err := proto.Unmarshal(cm.Message.Payload(), &rm); err != nil {
log.Printf("Unmarshal QueryReqMsg failed, error = %v", err)
break
}
if rm.ProxyId != s.proxyId {
break
}
qm, ok := resultMap[rm.ReqId]
if !ok {
log.Printf("unknown query id = %d", rm.ReqId)
break
}
qm.result = append(qm.result, &rm)
if len(qm.result) == s.numReaderNode {
qm.wg.Done()
delete(resultMap, rm.ReqId)
}
result.AckID(cm.ID())
}
}
}()
return nil
}
//func (s *proxyServer) reduceResult(query *queryReq) *servicepb.QueryResult {
//}
func (s *proxyServer) reduceResults(query *queryReq) *servicepb.QueryResult {
var results []*internalpb.SearchResult
var status commonpb.Status
status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
for _, r := range query.result {
status = *r.Status
if status.ErrorCode == commonpb.ErrorCode_SUCCESS {
results = append(results, r)
}else{
break
}
}
if len(results) != s.numReaderNode{
status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
}
if status.ErrorCode != commonpb.ErrorCode_SUCCESS{
result:= servicepb.QueryResult{
Status: &status,
}
return &result
}
if s.numReaderNode == 1 {
result:= servicepb.QueryResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_SUCCESS,
},
Hits: results[0].Hits,
}
return &result
}
//var entities []*struct {
// Idx int64
// Score float32
// Hit *servicepb.Hits
//}
//var rows int
//
//result_err := func(msg string) *pb.QueryResult {
// return &pb.QueryResult{
// Status: &pb.Status{
// ErrorCode: pb.ErrorCode_UNEXPECTED_ERROR,
// Reason: msg,
// },
// }
//}
//for _, r := range results {
// for i := 0; i < len(r.Hits); i++ {
// entity := struct {
// Ids int64
// ValidRow bool
// RowsData *pb.RowData
// Scores float32
// Distances float32
// }{
// Ids: r.Entities.Ids[i],
// ValidRow: r.Entities.ValidRow[i],
// RowsData: r.Entities.RowsData[i],
// Scores: r.Scores[i],
// Distances: r.Distances[i],
// }
// entities = append(entities, &entity)
// }
//}
//sort.Slice(entities, func(i, j int) bool {
// if entities[i].ValidRow == true {
// if entities[j].ValidRow == false {
// return true
// }
// return entities[i].Scores > entities[j].Scores
// } else {
// return false
// }
//})
//rIds := make([]int64, 0, rows)
//rValidRow := make([]bool, 0, rows)
//rRowsData := make([]*pb.RowData, 0, rows)
//rScores := make([]float32, 0, rows)
//rDistances := make([]float32, 0, rows)
//for i := 0; i < rows; i++ {
// rIds = append(rIds, entities[i].Ids)
// rValidRow = append(rValidRow, entities[i].ValidRow)
// rRowsData = append(rRowsData, entities[i].RowsData)
// rScores = append(rScores, entities[i].Scores)
// rDistances = append(rDistances, entities[i].Distances)
//}
return &servicepb.QueryResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_SUCCESS,
},
}
}