mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 02:48:45 +08:00
Add tests/benchmark and tests/python_test using new python SDK
Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
This commit is contained in:
parent
d5631b09e8
commit
84110d2684
@ -85,14 +85,14 @@ func (dsService *dataSyncService) initNodes() {
|
||||
var insertBufferNode Node = newInsertBufferNode(dsService.ctx, mt, dsService.replica, dsService.idAllocator, dsService.msFactory)
|
||||
var gcNode Node = newGCNode(dsService.replica)
|
||||
|
||||
dsService.fg.AddNode(&dmStreamNode)
|
||||
dsService.fg.AddNode(&ddStreamNode)
|
||||
dsService.fg.AddNode(dmStreamNode)
|
||||
dsService.fg.AddNode(ddStreamNode)
|
||||
|
||||
dsService.fg.AddNode(&filterDmNode)
|
||||
dsService.fg.AddNode(&ddNode)
|
||||
dsService.fg.AddNode(filterDmNode)
|
||||
dsService.fg.AddNode(ddNode)
|
||||
|
||||
dsService.fg.AddNode(&insertBufferNode)
|
||||
dsService.fg.AddNode(&gcNode)
|
||||
dsService.fg.AddNode(insertBufferNode)
|
||||
dsService.fg.AddNode(gcNode)
|
||||
|
||||
// dmStreamNode
|
||||
err = dsService.fg.SetEdges(dmStreamNode.Name(),
|
||||
|
@ -66,7 +66,7 @@ func (ddNode *ddNode) Name() string {
|
||||
return "ddNode"
|
||||
}
|
||||
|
||||
func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
|
||||
func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do filterDdNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -74,7 +74,7 @@ func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
msMsg, ok := (*in[0]).(*MsgStreamMsg)
|
||||
msMsg, ok := in[0].(*MsgStreamMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for MsgStreamMsg")
|
||||
// TODO: add error handling
|
||||
@ -141,7 +141,7 @@ func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
|
||||
}
|
||||
|
||||
var res Msg = ddNode.ddMsg
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (ddNode *ddNode) flush() {
|
||||
|
@ -154,5 +154,5 @@ func TestFlowGraphDDNode_Operate(t *testing.T) {
|
||||
tsMessages = append(tsMessages, msgstream.TsMsg(&dropPartitionMsg))
|
||||
msgStream := flowgraph.GenerateMsgStreamMsg(tsMessages, Timestamp(0), Timestamp(3), make([]*internalpb2.MsgPosition, 0))
|
||||
var inMsg Msg = msgStream
|
||||
ddNode.Operate([]*Msg{&inMsg})
|
||||
ddNode.Operate(ctx, []Msg{inMsg})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package datanode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math"
|
||||
|
||||
@ -18,7 +19,7 @@ func (fdmNode *filterDmNode) Name() string {
|
||||
return "fdmNode"
|
||||
}
|
||||
|
||||
func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do filterDmNode operation")
|
||||
|
||||
if len(in) != 2 {
|
||||
@ -26,13 +27,13 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
msgStreamMsg, ok := (*in[0]).(*MsgStreamMsg)
|
||||
msgStreamMsg, ok := in[0].(*MsgStreamMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for MsgStreamMsg")
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
ddMsg, ok := (*in[1]).(*ddMsg)
|
||||
ddMsg, ok := in[1].(*ddMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for ddMsg")
|
||||
// TODO: add error handling
|
||||
@ -69,7 +70,7 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
iMsg.startPositions = append(iMsg.startPositions, msgStreamMsg.StartPositions()...)
|
||||
iMsg.gcRecord = ddMsg.gcRecord
|
||||
var res Msg = &iMsg
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package datanode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
)
|
||||
|
||||
@ -13,7 +14,7 @@ func (gcNode *gcNode) Name() string {
|
||||
return "gcNode"
|
||||
}
|
||||
|
||||
func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do gcNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -21,7 +22,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
gcMsg, ok := (*in[0]).(*gcMsg)
|
||||
gcMsg, ok := in[0].(*gcMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for gcMsg")
|
||||
// TODO: add error handling
|
||||
@ -35,7 +36,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
func newGCNode(replica Replica) *gcNode {
|
||||
|
@ -85,7 +85,7 @@ func (ibNode *insertBufferNode) Name() string {
|
||||
return "ibNode"
|
||||
}
|
||||
|
||||
func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
|
||||
func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
// log.Println("=========== insert buffer Node Operating")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -93,7 +93,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
iMsg, ok := (*in[0]).(*insertMsg)
|
||||
iMsg, ok := in[0].(*insertMsg)
|
||||
if !ok {
|
||||
log.Println("Error: type assertion failed for insertMsg")
|
||||
// TODO: add error handling
|
||||
@ -472,7 +472,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
|
||||
timeRange: iMsg.timeRange,
|
||||
}
|
||||
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (ibNode *insertBufferNode) flushSegment(segID UniqueID, partitionID UniqueID, collID UniqueID) error {
|
||||
|
@ -53,7 +53,7 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) {
|
||||
iBNode := newInsertBufferNode(ctx, newMetaTable(), replica, idFactory, msFactory)
|
||||
inMsg := genInsertMsg()
|
||||
var iMsg flowgraph.Msg = &inMsg
|
||||
iBNode.Operate([]*flowgraph.Msg{&iMsg})
|
||||
iBNode.Operate(ctx, []flowgraph.Msg{iMsg})
|
||||
}
|
||||
|
||||
func genInsertMsg() insertMsg {
|
||||
|
@ -329,7 +329,7 @@ func (s *Server) startStatsChannel(ctx context.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
msgPack := statsStream.Consume()
|
||||
msgPack, _ := statsStream.Consume()
|
||||
for _, msg := range msgPack.Msgs {
|
||||
statistics, ok := msg.(*msgstream.SegmentStatisticsMsg)
|
||||
if !ok {
|
||||
@ -358,7 +358,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
msgPack := flushStream.Consume()
|
||||
msgPack, _ := flushStream.Consume()
|
||||
for _, msg := range msgPack.Msgs {
|
||||
if msg.Type() != commonpb.MsgType_kSegmentFlushDone {
|
||||
continue
|
||||
@ -393,7 +393,7 @@ func (s *Server) startDDChannel(ctx context.Context) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
msgPack := ddStream.Consume()
|
||||
msgPack, _ := ddStream.Consume()
|
||||
for _, msg := range msgPack.Msgs {
|
||||
if err := s.ddHandler.HandleDDMsg(msg); err != nil {
|
||||
log.Error("handle dd msg error", zap.Error(err))
|
||||
|
@ -2,18 +2,21 @@ package grpcindexservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/uber/jaeger-client-go/config"
|
||||
"github.com/zilliztech/milvus-distributed/internal/indexservice"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/indexpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
@ -30,6 +33,8 @@ type Server struct {
|
||||
loopCtx context.Context
|
||||
loopCancel func()
|
||||
loopWg sync.WaitGroup
|
||||
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
@ -71,6 +76,9 @@ func (s *Server) start() error {
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
if err := s.closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.impl != nil {
|
||||
s.impl.Stop()
|
||||
}
|
||||
@ -191,5 +199,19 @@ func NewServer(ctx context.Context) (*Server, error) {
|
||||
grpcErrChan: make(chan error),
|
||||
}
|
||||
|
||||
cfg := &config.Configuration{
|
||||
ServiceName: "index_service",
|
||||
Sampler: &config.SamplerConfig{
|
||||
Type: "const",
|
||||
Param: 1,
|
||||
},
|
||||
}
|
||||
tracer, closer, err := cfg.NewTracer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
|
||||
}
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
s.closer = closer
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ func NewServer(ctx context.Context, factory msgstream.Factory) (*Server, error)
|
||||
|
||||
//TODO
|
||||
cfg := &config.Configuration{
|
||||
ServiceName: "proxy_service",
|
||||
ServiceName: "master_service",
|
||||
Sampler: &config.SamplerConfig{
|
||||
Type: "const",
|
||||
Param: 1,
|
||||
|
@ -2,11 +2,15 @@ package indexnode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/uber/jaeger-client-go/config"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
|
||||
@ -40,6 +44,8 @@ type NodeImpl struct {
|
||||
// Add callback functions at different stages
|
||||
startCallbacks []func()
|
||||
closeCallbacks []func()
|
||||
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func NewNodeImpl(ctx context.Context) (*NodeImpl, error) {
|
||||
@ -57,9 +63,7 @@ func NewNodeImpl(ctx context.Context) (*NodeImpl, error) {
|
||||
}
|
||||
|
||||
func (i *NodeImpl) Init() error {
|
||||
log.Println("AAAAAAAAAAAAAAAAA", i.serviceClient)
|
||||
err := funcutil.WaitForComponentHealthy(i.serviceClient, "IndexService", 10, time.Second)
|
||||
log.Println("BBBBBBBBB", i.serviceClient)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -87,6 +91,21 @@ func (i *NodeImpl) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO
|
||||
cfg := &config.Configuration{
|
||||
ServiceName: fmt.Sprintf("index_node_%d", Params.NodeID),
|
||||
Sampler: &config.SamplerConfig{
|
||||
Type: "const",
|
||||
Param: 1,
|
||||
},
|
||||
}
|
||||
tracer, closer, err := cfg.NewTracer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
|
||||
}
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
i.closer = closer
|
||||
|
||||
connectMinIOFn := func() error {
|
||||
option := &miniokv.Option{
|
||||
Address: Params.MinIOAddress,
|
||||
@ -126,6 +145,9 @@ func (i *NodeImpl) Start() error {
|
||||
|
||||
// Close closes the server.
|
||||
func (i *NodeImpl) Stop() error {
|
||||
if err := i.closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
i.loopCancel()
|
||||
if i.sched != nil {
|
||||
i.sched.Close()
|
||||
|
@ -32,7 +32,7 @@ type MsgStream interface {
|
||||
|
||||
Produce(context.Context, *MsgPack) error
|
||||
Broadcast(context.Context, *MsgPack) error
|
||||
Consume() *MsgPack
|
||||
Consume() (*MsgPack, context.Context)
|
||||
Seek(offset *MsgPosition) error
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestStream_task_Insert(t *testing.T) {
|
||||
}
|
||||
receiveCount := 0
|
||||
for {
|
||||
result := outputStream.Consume()
|
||||
result, _ := outputStream.Consume()
|
||||
if len(result.Msgs) > 0 {
|
||||
msgs := result.Msgs
|
||||
for _, v := range msgs {
|
||||
|
@ -5,14 +5,12 @@ import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/log"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
@ -21,6 +19,7 @@ import (
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/trace"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type TsMsg = msgstream.TsMsg
|
||||
@ -52,6 +51,8 @@ type PulsarMsgStream struct {
|
||||
pulsarBufSize int64
|
||||
consumerLock *sync.Mutex
|
||||
consumerReflects []reflect.SelectCase
|
||||
|
||||
scMap *sync.Map
|
||||
}
|
||||
|
||||
func newPulsarMsgStream(ctx context.Context,
|
||||
@ -92,6 +93,7 @@ func newPulsarMsgStream(ctx context.Context,
|
||||
consumerReflects: consumerReflects,
|
||||
consumerLock: &sync.Mutex{},
|
||||
wait: &sync.WaitGroup{},
|
||||
scMap: &sync.Map{},
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
@ -182,29 +184,6 @@ func (ms *PulsarMsgStream) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
type propertiesReaderWriter struct {
|
||||
ppMap map[string]string
|
||||
}
|
||||
|
||||
func (ppRW *propertiesReaderWriter) Set(key, val string) {
|
||||
// The GRPC HPACK implementation rejects any uppercase keys here.
|
||||
//
|
||||
// As such, since the HTTP_HEADERS format is case-insensitive anyway, we
|
||||
// blindly lowercase the key (which is guaranteed to work in the
|
||||
// Inject/Extract sense per the OpenTracing spec).
|
||||
key = strings.ToLower(key)
|
||||
ppRW.ppMap[key] = val
|
||||
}
|
||||
|
||||
func (ppRW *propertiesReaderWriter) ForeachKey(handler func(key, val string) error) error {
|
||||
for k, val := range ppRW.ppMap {
|
||||
if err := handler(k, val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
|
||||
tsMsgs := msgPack.Msgs
|
||||
if len(tsMsgs) <= 0 {
|
||||
@ -316,18 +295,31 @@ func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *PulsarMsgStream) Consume() *MsgPack {
|
||||
func (ms *PulsarMsgStream) Consume() (*MsgPack, context.Context) {
|
||||
for {
|
||||
select {
|
||||
case cm, ok := <-ms.receiveBuf:
|
||||
if !ok {
|
||||
log.Debug("buf chan closed")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return cm
|
||||
var ctx context.Context
|
||||
var opts []opentracing.StartSpanOption
|
||||
for _, msg := range cm.Msgs {
|
||||
sc, loaded := ms.scMap.LoadAndDelete(msg.ID())
|
||||
if loaded {
|
||||
opts = append(opts, opentracing.ChildOf(sc.(opentracing.SpanContext)))
|
||||
}
|
||||
}
|
||||
if len(opts) != 0 {
|
||||
ctx = context.Background()
|
||||
}
|
||||
sp, ctx := trace.StartSpanFromContext(ctx, opts...)
|
||||
sp.Finish()
|
||||
return cm, ctx
|
||||
case <-ms.ctx.Done():
|
||||
log.Debug("context closed")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -360,8 +352,15 @@ func (ms *PulsarMsgStream) receiveMsg(consumer Consumer) {
|
||||
MsgID: typeutil.PulsarMsgIDToString(pulsarMsg.ID()),
|
||||
})
|
||||
|
||||
sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
|
||||
if ok {
|
||||
ms.scMap.Store(tsMsg.ID(), sp.Context())
|
||||
}
|
||||
|
||||
msgPack := MsgPack{Msgs: []TsMsg{tsMsg}}
|
||||
ms.receiveBuf <- &msgPack
|
||||
|
||||
sp.Finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -687,12 +686,18 @@ func (ms *PulsarTtMsgStream) findTimeTick(consumer Consumer,
|
||||
log.Error("Failed to unmarshal tsMsg", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// set pulsar info to tsMsg
|
||||
tsMsg.SetPosition(&msgstream.MsgPosition{
|
||||
ChannelName: filepath.Base(pulsarMsg.Topic()),
|
||||
MsgID: typeutil.PulsarMsgIDToString(pulsarMsg.ID()),
|
||||
})
|
||||
|
||||
sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
|
||||
if ok {
|
||||
ms.scMap.Store(tsMsg.ID(), sp.Context())
|
||||
}
|
||||
|
||||
ms.unsolvedMutex.Lock()
|
||||
ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg)
|
||||
ms.unsolvedMutex.Unlock()
|
||||
@ -701,8 +706,10 @@ func (ms *PulsarTtMsgStream) findTimeTick(consumer Consumer,
|
||||
findMapMutex.Lock()
|
||||
eofMsgMap[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp
|
||||
findMapMutex.Unlock()
|
||||
sp.Finish()
|
||||
return
|
||||
}
|
||||
sp.Finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ func initPulsarTtStream(pulsarAddress string,
|
||||
func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
|
||||
receiveCount := 0
|
||||
for {
|
||||
result := outputStream.Consume()
|
||||
result, _ := outputStream.Consume()
|
||||
if len(result.Msgs) > 0 {
|
||||
msgs := result.Msgs
|
||||
for _, v := range msgs {
|
||||
@ -607,13 +607,13 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
outputStream.Consume()
|
||||
receivedMsg := outputStream.Consume()
|
||||
receivedMsg, _ := outputStream.Consume()
|
||||
for _, position := range receivedMsg.StartPositions {
|
||||
outputStream.Seek(position)
|
||||
}
|
||||
err = inputStream.Broadcast(ctx, &msgPack5)
|
||||
assert.Nil(t, err)
|
||||
seekMsg := outputStream.Consume()
|
||||
seekMsg, _ := outputStream.Consume()
|
||||
for _, msg := range seekMsg.Msgs {
|
||||
assert.Equal(t, msg.BeginTs(), uint64(14))
|
||||
}
|
||||
|
@ -219,18 +219,18 @@ func (ms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *RmqMsgStream) Consume() *msgstream.MsgPack {
|
||||
func (ms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
|
||||
for {
|
||||
select {
|
||||
case cm, ok := <-ms.receiveBuf:
|
||||
if !ok {
|
||||
log.Println("buf chan closed")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return cm
|
||||
return cm, nil
|
||||
case <-ms.ctx.Done():
|
||||
log.Printf("context closed")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ func initRmqTtStream(producerChannels []string,
|
||||
func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
|
||||
receiveCount := 0
|
||||
for {
|
||||
result := outputStream.Consume()
|
||||
result, _ := outputStream.Consume()
|
||||
if len(result.Msgs) > 0 {
|
||||
msgs := result.Msgs
|
||||
for _, v := range msgs {
|
||||
|
@ -5,11 +5,9 @@ import (
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -58,6 +56,9 @@ func (tt *TimeTickImpl) Start() error {
|
||||
},
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
|
||||
for _, msg := range msgPack.Msgs {
|
||||
log.Println("msg type xxxxxxxxxxxxxxxxxxxxxxxx", msg.Type())
|
||||
}
|
||||
for _, channel := range tt.channels {
|
||||
err = channel.Broadcast(tt.ctx, &msgPack)
|
||||
if err != nil {
|
||||
|
@ -56,15 +56,15 @@ func (dsService *dataSyncService) initNodes() {
|
||||
var serviceTimeNode node = newServiceTimeNode(dsService.ctx, dsService.replica, dsService.msFactory)
|
||||
var gcNode node = newGCNode(dsService.replica)
|
||||
|
||||
dsService.fg.AddNode(&dmStreamNode)
|
||||
dsService.fg.AddNode(&ddStreamNode)
|
||||
dsService.fg.AddNode(dmStreamNode)
|
||||
dsService.fg.AddNode(ddStreamNode)
|
||||
|
||||
dsService.fg.AddNode(&filterDmNode)
|
||||
dsService.fg.AddNode(&ddNode)
|
||||
dsService.fg.AddNode(filterDmNode)
|
||||
dsService.fg.AddNode(ddNode)
|
||||
|
||||
dsService.fg.AddNode(&insertNode)
|
||||
dsService.fg.AddNode(&serviceTimeNode)
|
||||
dsService.fg.AddNode(&gcNode)
|
||||
dsService.fg.AddNode(insertNode)
|
||||
dsService.fg.AddNode(serviceTimeNode)
|
||||
dsService.fg.AddNode(gcNode)
|
||||
|
||||
// dmStreamNode
|
||||
var err = dsService.fg.SetEdges(dmStreamNode.Name(),
|
||||
|
@ -1,6 +1,7 @@
|
||||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -19,15 +20,15 @@ func (ddNode *ddNode) Name() string {
|
||||
return "ddNode"
|
||||
}
|
||||
|
||||
func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
|
||||
//fmt.Println("Do ddNode operation")
|
||||
func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do filterDmNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
log.Println("Invalid operate message input in ddNode, input length = ", len(in))
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
msMsg, ok := (*in[0]).(*MsgStreamMsg)
|
||||
msMsg, ok := in[0].(*MsgStreamMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for MsgStreamMsg")
|
||||
// TODO: add error handling
|
||||
@ -72,7 +73,7 @@ func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
|
||||
//}
|
||||
|
||||
var res Msg = ddNode.ddMsg
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math"
|
||||
|
||||
@ -18,7 +19,7 @@ func (fdmNode *filterDmNode) Name() string {
|
||||
return "fdmNode"
|
||||
}
|
||||
|
||||
func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do filterDmNode operation")
|
||||
|
||||
if len(in) != 2 {
|
||||
@ -26,13 +27,13 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
msgStreamMsg, ok := (*in[0]).(*MsgStreamMsg)
|
||||
msgStreamMsg, ok := in[0].(*MsgStreamMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for MsgStreamMsg")
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
ddMsg, ok := (*in[1]).(*ddMsg)
|
||||
ddMsg, ok := in[1].(*ddMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for ddMsg")
|
||||
// TODO: add error handling
|
||||
@ -63,7 +64,7 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
|
||||
iMsg.gcRecord = ddMsg.gcRecord
|
||||
var res Msg = &iMsg
|
||||
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
)
|
||||
|
||||
@ -13,7 +14,7 @@ func (gcNode *gcNode) Name() string {
|
||||
return "gcNode"
|
||||
}
|
||||
|
||||
func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do gcNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -21,7 +22,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
_, ok := (*in[0]).(*gcMsg)
|
||||
_, ok := in[0].(*gcMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for gcMsg")
|
||||
// TODO: add error handling
|
||||
@ -47,7 +48,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg {
|
||||
// }
|
||||
//}
|
||||
|
||||
return nil
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
func newGCNode(replica collectionReplica) *gcNode {
|
||||
|
@ -26,7 +26,7 @@ func (iNode *insertNode) Name() string {
|
||||
return "iNode"
|
||||
}
|
||||
|
||||
func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
||||
func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
// fmt.Println("Do insertNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -34,7 +34,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
iMsg, ok := (*in[0]).(*insertMsg)
|
||||
iMsg, ok := in[0].(*insertMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for insertMsg")
|
||||
// TODO: add error handling
|
||||
@ -90,7 +90,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
||||
gcRecord: iMsg.gcRecord,
|
||||
timeRange: iMsg.timeRange,
|
||||
}
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
|
||||
|
@ -19,7 +19,7 @@ func (stNode *serviceTimeNode) Name() string {
|
||||
return "stNode"
|
||||
}
|
||||
|
||||
func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
|
||||
func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do serviceTimeNode operation")
|
||||
|
||||
if len(in) != 1 {
|
||||
@ -27,7 +27,7 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
|
||||
// TODO: add error handling
|
||||
}
|
||||
|
||||
serviceTimeMsg, ok := (*in[0]).(*serviceTimeMsg)
|
||||
serviceTimeMsg, ok := in[0].(*serviceTimeMsg)
|
||||
if !ok {
|
||||
log.Println("type assertion failed for serviceTimeMsg")
|
||||
// TODO: add error handling
|
||||
@ -45,7 +45,7 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
|
||||
gcRecord: serviceTimeMsg.gcRecord,
|
||||
timeRange: serviceTimeMsg.timeRange,
|
||||
}
|
||||
return []*Msg{&res}
|
||||
return []Msg{res}, ctx
|
||||
}
|
||||
|
||||
func (stNode *serviceTimeNode) sendTimeTick(ts Timestamp) error {
|
||||
|
@ -15,12 +15,12 @@ import "C"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/uber/jaeger-client-go/config"
|
||||
"io"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/uber/jaeger-client-go/config"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
|
||||
@ -91,20 +91,6 @@ func NewQueryNode(ctx context.Context, queryNodeID UniqueID, factory msgstream.F
|
||||
msFactory: factory,
|
||||
}
|
||||
|
||||
cfg := &config.Configuration{
|
||||
ServiceName: fmt.Sprintf("query_node_%d", node.QueryNodeID),
|
||||
Sampler: &config.SamplerConfig{
|
||||
Type: "const",
|
||||
Param: 1,
|
||||
},
|
||||
}
|
||||
tracer, closer, err := cfg.NewTracer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
|
||||
}
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
node.closer = closer
|
||||
|
||||
node.replica = newCollectionReplicaImpl()
|
||||
node.UpdateStateCode(internalpb2.StateCode_ABNORMAL)
|
||||
return node
|
||||
@ -167,6 +153,20 @@ func (node *QueryNode) Init() error {
|
||||
|
||||
fmt.Println("QueryNodeID is", Params.QueryNodeID)
|
||||
|
||||
cfg := &config.Configuration{
|
||||
ServiceName: fmt.Sprintf("query_node_%d", node.QueryNodeID),
|
||||
Sampler: &config.SamplerConfig{
|
||||
Type: "const",
|
||||
Param: 1,
|
||||
},
|
||||
}
|
||||
tracer, closer, err := cfg.NewTracer()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
|
||||
}
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
node.closer = closer
|
||||
|
||||
if node.masterClient == nil {
|
||||
log.Println("WARN: null master service detected")
|
||||
}
|
||||
@ -212,9 +212,6 @@ func (node *QueryNode) Start() error {
|
||||
}
|
||||
|
||||
func (node *QueryNode) Stop() error {
|
||||
if err := node.closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
node.UpdateStateCode(internalpb2.StateCode_ABNORMAL)
|
||||
node.queryNodeLoopCancel()
|
||||
|
||||
|
@ -121,7 +121,7 @@ func (ss *searchService) receiveSearchMsg() {
|
||||
case <-ss.ctx.Done():
|
||||
return
|
||||
default:
|
||||
msgPack := ss.searchMsgStream.Consume()
|
||||
msgPack, _ := ss.searchMsgStream.Consume()
|
||||
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ func (ttBarrier *softTimeTickBarrier) Start() {
|
||||
return
|
||||
default:
|
||||
}
|
||||
ttmsgs := ttBarrier.ttStream.Consume()
|
||||
ttmsgs, _ := ttBarrier.ttStream.Consume()
|
||||
if len(ttmsgs.Msgs) > 0 {
|
||||
for _, timetickmsg := range ttmsgs.Msgs {
|
||||
ttmsg := timetickmsg.(*ms.TimeTickMsg)
|
||||
@ -156,7 +156,7 @@ func (ttBarrier *hardTimeTickBarrier) Start() {
|
||||
return
|
||||
default:
|
||||
}
|
||||
ttmsgs := ttBarrier.ttStream.Consume()
|
||||
ttmsgs, _ := ttBarrier.ttStream.Consume()
|
||||
if len(ttmsgs.Msgs) > 0 {
|
||||
log.Printf("receive tt msg")
|
||||
for _, timetickmsg := range ttmsgs.Msgs {
|
||||
|
@ -13,11 +13,11 @@ type TimeTickedFlowGraph struct {
|
||||
nodeCtx map[NodeName]*nodeCtx
|
||||
}
|
||||
|
||||
func (fg *TimeTickedFlowGraph) AddNode(node *Node) {
|
||||
nodeName := (*node).Name()
|
||||
func (fg *TimeTickedFlowGraph) AddNode(node Node) {
|
||||
nodeName := node.Name()
|
||||
nodeCtx := nodeCtx{
|
||||
node: node,
|
||||
inputChannels: make([]chan *Msg, 0),
|
||||
inputChannels: make([]chan *MsgWithCtx, 0),
|
||||
downstreamInputChanIdx: make(map[string]int),
|
||||
}
|
||||
fg.nodeCtx[nodeName] = &nodeCtx
|
||||
@ -50,8 +50,8 @@ func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []stri
|
||||
errMsg := "Cannot find out node:" + n
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
maxQueueLength := (*outNode.node).MaxQueueLength()
|
||||
outNode.inputChannels = append(outNode.inputChannels, make(chan *Msg, maxQueueLength))
|
||||
maxQueueLength := outNode.node.MaxQueueLength()
|
||||
outNode.inputChannels = append(outNode.inputChannels, make(chan *MsgWithCtx, maxQueueLength))
|
||||
currentNode.downstream[i] = outNode
|
||||
}
|
||||
|
||||
@ -70,8 +70,8 @@ func (fg *TimeTickedFlowGraph) Start() {
|
||||
func (fg *TimeTickedFlowGraph) Close() {
|
||||
for _, v := range fg.nodeCtx {
|
||||
// close message stream
|
||||
if (*v.node).IsInputNode() {
|
||||
inStream, ok := (*v.node).(*InputNode)
|
||||
if v.node.IsInputNode() {
|
||||
inStream, ok := v.node.(*InputNode)
|
||||
if !ok {
|
||||
log.Fatal("Invalid inputNode")
|
||||
}
|
||||
|
@ -47,19 +47,19 @@ func (m *intMsg) DownStreamNodeIdx() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func intMsg2Msg(in []*intMsg) []*Msg {
|
||||
out := make([]*Msg, 0)
|
||||
func intMsg2Msg(in []*intMsg) []Msg {
|
||||
out := make([]Msg, 0)
|
||||
for _, msg := range in {
|
||||
var m Msg = msg
|
||||
out = append(out, &m)
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func msg2IntMsg(in []*Msg) []*intMsg {
|
||||
func msg2IntMsg(in []Msg) []*intMsg {
|
||||
out := make([]*intMsg, 0)
|
||||
for _, msg := range in {
|
||||
out = append(out, (*msg).(*intMsg))
|
||||
out = append(out, msg.(*intMsg))
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -68,43 +68,43 @@ func (a *nodeA) Name() string {
|
||||
return "NodeA"
|
||||
}
|
||||
|
||||
func (a *nodeA) Operate(in []*Msg) []*Msg {
|
||||
return append(in, in...)
|
||||
func (a *nodeA) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
return append(in, in...), nil
|
||||
}
|
||||
|
||||
func (b *nodeB) Name() string {
|
||||
return "NodeB"
|
||||
}
|
||||
|
||||
func (b *nodeB) Operate(in []*Msg) []*Msg {
|
||||
func (b *nodeB) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
messages := make([]*intMsg, 0)
|
||||
for _, msg := range msg2IntMsg(in) {
|
||||
messages = append(messages, &intMsg{
|
||||
num: math.Pow(msg.num, 2),
|
||||
})
|
||||
}
|
||||
return intMsg2Msg(messages)
|
||||
return intMsg2Msg(messages), nil
|
||||
}
|
||||
|
||||
func (c *nodeC) Name() string {
|
||||
return "NodeC"
|
||||
}
|
||||
|
||||
func (c *nodeC) Operate(in []*Msg) []*Msg {
|
||||
func (c *nodeC) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
messages := make([]*intMsg, 0)
|
||||
for _, msg := range msg2IntMsg(in) {
|
||||
messages = append(messages, &intMsg{
|
||||
num: math.Sqrt(msg.num),
|
||||
})
|
||||
}
|
||||
return intMsg2Msg(messages)
|
||||
return intMsg2Msg(messages), nil
|
||||
}
|
||||
|
||||
func (d *nodeD) Name() string {
|
||||
return "NodeD"
|
||||
}
|
||||
|
||||
func (d *nodeD) Operate(in []*Msg) []*Msg {
|
||||
func (d *nodeD) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
|
||||
messages := make([]*intMsg, 0)
|
||||
outLength := len(in) / 2
|
||||
inMessages := msg2IntMsg(in)
|
||||
@ -117,7 +117,7 @@ func (d *nodeD) Operate(in []*Msg) []*Msg {
|
||||
d.d = messages[0].num
|
||||
d.resChan <- d.d
|
||||
fmt.Println("flow graph result:", d.d)
|
||||
return intMsg2Msg(messages)
|
||||
return intMsg2Msg(messages), nil
|
||||
}
|
||||
|
||||
func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
|
||||
@ -129,8 +129,12 @@ func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
|
||||
time.Sleep(time.Millisecond * time.Duration(500))
|
||||
var num = float64(rand.Int() % 100)
|
||||
var msg Msg = &intMsg{num: num}
|
||||
var msgWithContext = &MsgWithCtx{
|
||||
ctx: ctx,
|
||||
msg: msg,
|
||||
}
|
||||
a := nodeA{}
|
||||
fg.nodeCtx[a.Name()].inputChannels[0] <- &msg
|
||||
fg.nodeCtx[a.Name()].inputChannels[0] <- msgWithContext
|
||||
fmt.Println("send number", num, "to node", a.Name())
|
||||
res, ok := receiveResult(ctx, fg)
|
||||
if !ok {
|
||||
@ -156,7 +160,7 @@ func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) {
|
||||
func receiveResultFromNodeD(res *float64, fg *TimeTickedFlowGraph, wg *sync.WaitGroup) {
|
||||
d := nodeD{}
|
||||
node := fg.nodeCtx[d.Name()]
|
||||
nd, ok := (*node.node).(*nodeD)
|
||||
nd, ok := node.node.(*nodeD)
|
||||
if !ok {
|
||||
log.Fatal("not nodeD type")
|
||||
}
|
||||
@ -167,7 +171,7 @@ func receiveResultFromNodeD(res *float64, fg *TimeTickedFlowGraph, wg *sync.Wait
|
||||
func receiveResult(ctx context.Context, fg *TimeTickedFlowGraph) (float64, bool) {
|
||||
d := nodeD{}
|
||||
node := fg.nodeCtx[d.Name()]
|
||||
nd, ok := (*node.node).(*nodeD)
|
||||
nd, ok := node.node.(*nodeD)
|
||||
if !ok {
|
||||
log.Fatal("not nodeD type")
|
||||
}
|
||||
@ -211,10 +215,10 @@ func TestTimeTickedFlowGraph_Start(t *testing.T) {
|
||||
resChan: make(chan float64),
|
||||
}
|
||||
|
||||
fg.AddNode(&a)
|
||||
fg.AddNode(&b)
|
||||
fg.AddNode(&c)
|
||||
fg.AddNode(&d)
|
||||
fg.AddNode(a)
|
||||
fg.AddNode(b)
|
||||
fg.AddNode(c)
|
||||
fg.AddNode(d)
|
||||
|
||||
var err = fg.SetEdges(a.Name(),
|
||||
[]string{},
|
||||
@ -250,7 +254,7 @@ func TestTimeTickedFlowGraph_Start(t *testing.T) {
|
||||
|
||||
// init node A
|
||||
nodeCtxA := fg.nodeCtx[a.Name()]
|
||||
nodeCtxA.inputChannels = []chan *Msg{make(chan *Msg, 10)}
|
||||
nodeCtxA.inputChannels = []chan *MsgWithCtx{make(chan *MsgWithCtx, 10)}
|
||||
|
||||
go fg.Start()
|
||||
|
||||
|
@ -1,9 +1,13 @@
|
||||
package flowgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/trace"
|
||||
)
|
||||
|
||||
type InputNode struct {
|
||||
@ -25,15 +29,19 @@ func (inNode *InputNode) InStream() *msgstream.MsgStream {
|
||||
}
|
||||
|
||||
// empty input and return one *Msg
|
||||
func (inNode *InputNode) Operate([]*Msg) []*Msg {
|
||||
func (inNode *InputNode) Operate(ctx context.Context, msgs []Msg) ([]Msg, context.Context) {
|
||||
//fmt.Println("Do InputNode operation")
|
||||
|
||||
msgPack := (*inNode.inStream).Consume()
|
||||
msgPack, ctx := (*inNode.inStream).Consume()
|
||||
|
||||
sp, ctx := trace.StartSpanFromContext(ctx, opentracing.Tag{Key: "NodeName", Value: inNode.Name()})
|
||||
defer sp.Finish()
|
||||
|
||||
// TODO: add status
|
||||
if msgPack == nil {
|
||||
log.Println("null msg pack")
|
||||
return nil
|
||||
trace.LogError(sp, errors.New("null msg pack"))
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
var msgStreamMsg Msg = &MsgStreamMsg{
|
||||
@ -43,7 +51,7 @@ func (inNode *InputNode) Operate([]*Msg) []*Msg {
|
||||
startPositions: msgPack.StartPositions,
|
||||
}
|
||||
|
||||
return []*Msg{&msgStreamMsg}
|
||||
return []Msg{msgStreamMsg}, ctx
|
||||
}
|
||||
|
||||
func NewInputNode(inStream *msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode {
|
||||
|
@ -6,13 +6,16 @@ import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/trace"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
Name() string
|
||||
MaxQueueLength() int32
|
||||
MaxParallelism() int32
|
||||
Operate(in []*Msg) []*Msg
|
||||
Operate(ctx context.Context, in []Msg) ([]Msg, context.Context)
|
||||
IsInputNode() bool
|
||||
}
|
||||
|
||||
@ -22,9 +25,9 @@ type BaseNode struct {
|
||||
}
|
||||
|
||||
type nodeCtx struct {
|
||||
node *Node
|
||||
inputChannels []chan *Msg
|
||||
inputMessages []*Msg
|
||||
node Node
|
||||
inputChannels []chan *MsgWithCtx
|
||||
inputMessages []Msg
|
||||
downstream []*nodeCtx
|
||||
downstreamInputChanIdx map[string]int
|
||||
|
||||
@ -32,10 +35,15 @@ type nodeCtx struct {
|
||||
NumCompletedTasks int64
|
||||
}
|
||||
|
||||
type MsgWithCtx struct {
|
||||
ctx context.Context
|
||||
msg Msg
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
if (*nodeCtx.node).IsInputNode() {
|
||||
if nodeCtx.node.IsInputNode() {
|
||||
// fmt.Println("start InputNode.inStream")
|
||||
inStream, ok := (*nodeCtx.node).(*InputNode)
|
||||
inStream, ok := nodeCtx.node.(*InputNode)
|
||||
if !ok {
|
||||
log.Fatal("Invalid inputNode")
|
||||
}
|
||||
@ -46,19 +54,23 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
wg.Done()
|
||||
fmt.Println((*nodeCtx.node).Name(), "closed")
|
||||
fmt.Println(nodeCtx.node.Name(), "closed")
|
||||
return
|
||||
default:
|
||||
// inputs from inputsMessages for Operate
|
||||
inputs := make([]*Msg, 0)
|
||||
inputs := make([]Msg, 0)
|
||||
|
||||
if !(*nodeCtx.node).IsInputNode() {
|
||||
nodeCtx.collectInputMessages()
|
||||
var msgCtx context.Context
|
||||
var res []Msg
|
||||
var sp opentracing.Span
|
||||
if !nodeCtx.node.IsInputNode() {
|
||||
msgCtx = nodeCtx.collectInputMessages()
|
||||
inputs = nodeCtx.inputMessages
|
||||
}
|
||||
|
||||
n := *nodeCtx.node
|
||||
res := n.Operate(inputs)
|
||||
n := nodeCtx.node
|
||||
res, msgCtx = n.Operate(msgCtx, inputs)
|
||||
sp, msgCtx = trace.StartSpanFromContext(msgCtx)
|
||||
sp.SetTag("node name", n.Name())
|
||||
|
||||
downstreamLength := len(nodeCtx.downstreamInputChanIdx)
|
||||
if len(nodeCtx.downstream) < downstreamLength {
|
||||
@ -72,9 +84,10 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
w := sync.WaitGroup{}
|
||||
for i := 0; i < downstreamLength; i++ {
|
||||
w.Add(1)
|
||||
go nodeCtx.downstream[i].ReceiveMsg(&w, res[i], nodeCtx.downstreamInputChanIdx[(*nodeCtx.downstream[i].node).Name()])
|
||||
go nodeCtx.downstream[i].ReceiveMsg(msgCtx, &w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()])
|
||||
}
|
||||
w.Wait()
|
||||
sp.Finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -86,38 +99,54 @@ func (nodeCtx *nodeCtx) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg *Msg, inputChanIdx int) {
|
||||
nodeCtx.inputChannels[inputChanIdx] <- msg
|
||||
func (nodeCtx *nodeCtx) ReceiveMsg(ctx context.Context, wg *sync.WaitGroup, msg Msg, inputChanIdx int) {
|
||||
sp, ctx := trace.StartSpanFromContext(ctx)
|
||||
defer sp.Finish()
|
||||
nodeCtx.inputChannels[inputChanIdx] <- &MsgWithCtx{ctx: ctx, msg: msg}
|
||||
//fmt.Println((*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx)
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
func (nodeCtx *nodeCtx) collectInputMessages() context.Context {
|
||||
var opts []opentracing.StartSpanOption
|
||||
|
||||
inputsNum := len(nodeCtx.inputChannels)
|
||||
nodeCtx.inputMessages = make([]*Msg, inputsNum)
|
||||
nodeCtx.inputMessages = make([]Msg, inputsNum)
|
||||
|
||||
// init inputMessages,
|
||||
// receive messages from inputChannels,
|
||||
// and move them to inputMessages.
|
||||
for i := 0; i < inputsNum; i++ {
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
msg, ok := <-channel
|
||||
msgWithCtx, ok := <-channel
|
||||
if !ok {
|
||||
// TODO: add status
|
||||
log.Println("input channel closed")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
nodeCtx.inputMessages[i] = msgWithCtx.msg
|
||||
if msgWithCtx.ctx != nil {
|
||||
sp, _ := trace.StartSpanFromContext(msgWithCtx.ctx)
|
||||
opts = append(opts, opentracing.ChildOf(sp.Context()))
|
||||
sp.Finish()
|
||||
}
|
||||
}
|
||||
|
||||
var ctx context.Context
|
||||
var sp opentracing.Span
|
||||
if len(opts) != 0 {
|
||||
sp, ctx = trace.StartSpanFromContext(context.Background(), opts...)
|
||||
defer sp.Finish()
|
||||
}
|
||||
|
||||
// timeTick alignment check
|
||||
if len(nodeCtx.inputMessages) > 1 {
|
||||
t := (*nodeCtx.inputMessages[0]).TimeTick()
|
||||
t := nodeCtx.inputMessages[0].TimeTick()
|
||||
latestTime := t
|
||||
for i := 1; i < len(nodeCtx.inputMessages); i++ {
|
||||
if t < (*nodeCtx.inputMessages[i]).TimeTick() {
|
||||
latestTime = (*nodeCtx.inputMessages[i]).TimeTick()
|
||||
if t < nodeCtx.inputMessages[i].TimeTick() {
|
||||
latestTime = nodeCtx.inputMessages[i].TimeTick()
|
||||
//err := errors.New("Fatal, misaligned time tick," +
|
||||
// "t1=" + strconv.FormatUint(time, 10) +
|
||||
// ", t2=" + strconv.FormatUint((*nodeCtx.inputMessages[i]).TimeTick(), 10) +
|
||||
@ -127,7 +156,7 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
}
|
||||
// wait for time tick
|
||||
for i := 0; i < len(nodeCtx.inputMessages); i++ {
|
||||
for (*nodeCtx.inputMessages[i]).TimeTick() != latestTime {
|
||||
for nodeCtx.inputMessages[i].TimeTick() != latestTime {
|
||||
channel := nodeCtx.inputChannels[i]
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
@ -135,13 +164,14 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
|
||||
case msg, ok := <-channel:
|
||||
if !ok {
|
||||
log.Println("input channel closed")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
nodeCtx.inputMessages[i] = msg
|
||||
nodeCtx.inputMessages[i] = msg.msg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (node *BaseNode) MaxQueueLength() int32 {
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
|
||||
func StartSpanFromContext(ctx context.Context, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
|
||||
if ctx == nil {
|
||||
panic("StartSpanFromContext called with nil context")
|
||||
return noopSpan(), ctx
|
||||
}
|
||||
|
||||
var pcs [1]uintptr
|
||||
@ -45,7 +45,7 @@ func StartSpanFromContext(ctx context.Context, opts ...opentracing.StartSpanOpti
|
||||
|
||||
func StartSpanFromContextWithOperationName(ctx context.Context, operationName string, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
|
||||
if ctx == nil {
|
||||
panic("StartSpanFromContextWithOperationName called with nil context")
|
||||
return noopSpan(), ctx
|
||||
}
|
||||
|
||||
var pcs [1]uintptr
|
||||
@ -109,9 +109,9 @@ func InjectContextToPulsarMsgProperties(sc opentracing.SpanContext, properties m
|
||||
tracer.Inject(sc, opentracing.TextMap, propertiesReaderWriter{properties})
|
||||
}
|
||||
|
||||
func ExtractFromPulsarMsgProperties(msg msgstream.TsMsg, properties map[string]string) opentracing.Span {
|
||||
func ExtractFromPulsarMsgProperties(msg msgstream.TsMsg, properties map[string]string) (opentracing.Span, bool) {
|
||||
if !allowTrace(msg) {
|
||||
return noopSpan()
|
||||
return noopSpan(), false
|
||||
}
|
||||
tracer := opentracing.GlobalTracer()
|
||||
sc, _ := tracer.Extract(opentracing.TextMap, propertiesReaderWriter{properties})
|
||||
@ -124,21 +124,42 @@ func ExtractFromPulsarMsgProperties(msg msgstream.TsMsg, properties map[string]s
|
||||
"HashKeys": msg.HashKeys(),
|
||||
"Position": msg.Position(),
|
||||
}}
|
||||
return opentracing.StartSpan(name, opts...)
|
||||
return opentracing.StartSpan(name, opts...), true
|
||||
}
|
||||
|
||||
func MsgSpanFromCtx(ctx context.Context, msg msgstream.TsMsg, opts ...opentracing.StartSpanOption) (opentracing.Span, context.Context) {
|
||||
if ctx == nil {
|
||||
return noopSpan(), ctx
|
||||
}
|
||||
if !allowTrace(msg) {
|
||||
return noopSpan(), ctx
|
||||
}
|
||||
name := "send pulsar msg"
|
||||
operationName := "send pulsar msg"
|
||||
opts = append(opts, opentracing.Tags{
|
||||
"ID": msg.ID(),
|
||||
"Type": msg.Type(),
|
||||
"HashKeys": msg.HashKeys(),
|
||||
"Position": msg.Position(),
|
||||
})
|
||||
return StartSpanFromContextWithOperationName(ctx, name, opts...)
|
||||
|
||||
var pcs [1]uintptr
|
||||
n := runtime.Callers(2, pcs[:])
|
||||
if n < 1 {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, operationName, opts...)
|
||||
span.LogFields(log.Error(errors.New("runtime.Callers failed")))
|
||||
return span, ctx
|
||||
}
|
||||
file, line := runtime.FuncForPC(pcs[0]).FileLine(pcs[0])
|
||||
|
||||
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
|
||||
opts = append(opts, opentracing.ChildOf(parentSpan.Context()))
|
||||
}
|
||||
span := opentracing.StartSpan(operationName, opts...)
|
||||
ctx = opentracing.ContextWithSpan(ctx, span)
|
||||
|
||||
span.LogFields(log.String("filename", file), log.Int("line", line))
|
||||
|
||||
return span, ctx
|
||||
}
|
||||
|
||||
type propertiesReaderWriter struct {
|
||||
|
39
tests/benchmark/README.md
Normal file
39
tests/benchmark/README.md
Normal file
@ -0,0 +1,39 @@
|
||||
# Quick start
|
||||
|
||||
### Description:
|
||||
|
||||
This project is used to test performance/reliability/stability for milvus server
|
||||
- Test cases can be organized with `yaml`
|
||||
- Test can run with local mode or helm mode
|
||||
|
||||
### Usage:
|
||||
`pip install requirements.txt`
|
||||
|
||||
if using local mode, the following libs is optional
|
||||
|
||||
`pymongo==3.10.0`
|
||||
|
||||
`kubernetes==10.0.1`
|
||||
|
||||
### Demos:
|
||||
|
||||
1. Local test:
|
||||
|
||||
`python3 main.py --local --host=*.* --port=19530 --suite=suites/gpu_search_performance_random50m.yaml`
|
||||
|
||||
### Definitions of test suites:
|
||||
|
||||
Testers need to write test suite config if adding a customizised test into the current test framework
|
||||
|
||||
1. search_performance: the test type,also we have`build_performance`,`insert_performance`,`accuracy`,`stability`,`search_stability`
|
||||
2. tables: list of test cases
|
||||
3. The following fields are in the `table` field:
|
||||
- server: run host
|
||||
- milvus: config in milvus
|
||||
- collection_name: currently support one collection
|
||||
- run_count: search count
|
||||
- search_params: params of query
|
||||
|
||||
## Test result:
|
||||
|
||||
Test result will be uploaded if tests run in helm mode, and will be used to judge if the test run pass or failed
|
0
tests/benchmark/__init__.py
Normal file
0
tests/benchmark/__init__.py
Normal file
BIN
tests/benchmark/assets/Parameters.png
Normal file
BIN
tests/benchmark/assets/Parameters.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 50 KiB |
BIN
tests/benchmark/assets/gpu_search_performance_random50m-yaml.png
Normal file
BIN
tests/benchmark/assets/gpu_search_performance_random50m-yaml.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 65 KiB |
Binary file not shown.
After Width: | Height: | Size: 44 KiB |
10
tests/benchmark/ci/function/file_transfer.groovy
Normal file
10
tests/benchmark/ci/function/file_transfer.groovy
Normal file
@ -0,0 +1,10 @@
|
||||
def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) {
|
||||
if (protocol == "ftp") {
|
||||
ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [
|
||||
[configName: "${remoteIP}", transfers: [
|
||||
[asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true
|
||||
]
|
||||
]
|
||||
}
|
||||
}
|
||||
return this
|
13
tests/benchmark/ci/jenkinsfile/cleanup.groovy
Normal file
13
tests/benchmark/ci/jenkinsfile/cleanup.groovy
Normal file
@ -0,0 +1,13 @@
|
||||
try {
|
||||
def result = sh script: "helm status -n milvus ${env.HELM_RELEASE_NAME}", returnStatus: true
|
||||
if (!result) {
|
||||
sh "helm uninstall -n milvus ${env.HELM_RELEASE_NAME}"
|
||||
}
|
||||
} catch (exc) {
|
||||
def result = sh script: "helm status -n milvus ${env.HELM_RELEASE_NAME}", returnStatus: true
|
||||
if (!result) {
|
||||
sh "helm uninstall -n milvus ${env.HELM_RELEASE_NAME}"
|
||||
}
|
||||
throw exc
|
||||
}
|
||||
|
13
tests/benchmark/ci/jenkinsfile/cleanupShards.groovy
Normal file
13
tests/benchmark/ci/jenkinsfile/cleanupShards.groovy
Normal file
@ -0,0 +1,13 @@
|
||||
try {
|
||||
def result = sh script: "helm status -n milvus ${env.HELM_SHARDS_RELEASE_NAME}", returnStatus: true
|
||||
if (!result) {
|
||||
sh "helm uninstall -n milvus ${env.HELM_SHARDS_RELEASE_NAME}"
|
||||
}
|
||||
} catch (exc) {
|
||||
def result = sh script: "helm status -n milvus ${env.HELM_SHARDS_RELEASE_NAME}", returnStatus: true
|
||||
if (!result) {
|
||||
sh "helm uninstall -n milvus ${env.HELM_SHARDS_RELEASE_NAME}"
|
||||
}
|
||||
throw exc
|
||||
}
|
||||
|
21
tests/benchmark/ci/jenkinsfile/deploy_shards_test.groovy
Normal file
21
tests/benchmark/ci/jenkinsfile/deploy_shards_test.groovy
Normal file
@ -0,0 +1,21 @@
|
||||
timeout(time: 12, unit: 'HOURS') {
|
||||
try {
|
||||
dir ("milvus-helm") {
|
||||
// sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
|
||||
// sh 'helm repo update'
|
||||
checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], userRemoteConfigs: [[url: "${HELM_URL}", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]])
|
||||
}
|
||||
dir ("milvus_benchmark") {
|
||||
print "Git clone url: ${TEST_URL}:${TEST_BRANCH}"
|
||||
checkout([$class: 'GitSCM', branches: [[name: "${TEST_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "${TEST_URL}", name: 'origin', refspec: "+refs/heads/${TEST_BRANCH}:refs/remotes/origin/${TEST_BRANCH}"]]])
|
||||
print "Install requirements"
|
||||
// sh "python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com"
|
||||
sh "python3 -m pip install -r requirements.txt"
|
||||
sh "python3 -m pip install git+${TEST_LIB_URL}"
|
||||
sh "python3 main.py --image-version=${params.IMAGE_VERSION} --schedule-conf=scheduler/${params.SHARDS_CONFIG_FILE} --deploy-mode=${params.DEPLOY_MODE}"
|
||||
}
|
||||
} catch (exc) {
|
||||
echo 'Deploy SHARDS Test Failed !'
|
||||
throw exc
|
||||
}
|
||||
}
|
19
tests/benchmark/ci/jenkinsfile/deploy_test.groovy
Normal file
19
tests/benchmark/ci/jenkinsfile/deploy_test.groovy
Normal file
@ -0,0 +1,19 @@
|
||||
try {
|
||||
dir ("milvus-helm") {
|
||||
// sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
|
||||
// sh 'helm repo update'
|
||||
checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], userRemoteConfigs: [[url: "${HELM_URL}", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]])
|
||||
}
|
||||
dir ("milvus_benchmark") {
|
||||
print "Git clone url: ${TEST_URL}:${TEST_BRANCH}"
|
||||
checkout([$class: 'GitSCM', branches: [[name: "${TEST_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "${TEST_URL}", name: 'origin', refspec: "+refs/heads/${TEST_BRANCH}:refs/remotes/origin/${TEST_BRANCH}"]]])
|
||||
print "Install requirements"
|
||||
sh "python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com"
|
||||
// sh "python3 -m pip install -r requirements.txt"
|
||||
sh "python3 -m pip install git+${TEST_LIB_URL}"
|
||||
sh "python3 main.py --image-version=${params.IMAGE_VERSION} --schedule-conf=scheduler/${params.CONFIG_FILE} --deploy-mode=${params.DEPLOY_MODE}"
|
||||
}
|
||||
} catch (exc) {
|
||||
echo 'Deploy Test Failed !'
|
||||
throw exc
|
||||
}
|
15
tests/benchmark/ci/jenkinsfile/notify.groovy
Normal file
15
tests/benchmark/ci/jenkinsfile/notify.groovy
Normal file
@ -0,0 +1,15 @@
|
||||
def notify() {
|
||||
if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
|
||||
// Send an email only if the build status has changed from green/unstable to red
|
||||
emailext subject: '$DEFAULT_SUBJECT',
|
||||
body: '$DEFAULT_CONTENT',
|
||||
recipientProviders: [
|
||||
[$class: 'DevelopersRecipientProvider'],
|
||||
[$class: 'RequesterRecipientProvider']
|
||||
],
|
||||
replyTo: '$DEFAULT_REPLYTO',
|
||||
to: '$DEFAULT_RECIPIENTS'
|
||||
}
|
||||
}
|
||||
return this
|
||||
|
46
tests/benchmark/ci/jenkinsfile/publishDailyImages.groovy
Normal file
46
tests/benchmark/ci/jenkinsfile/publishDailyImages.groovy
Normal file
@ -0,0 +1,46 @@
|
||||
timeout(time: 30, unit: 'MINUTES') {
|
||||
def imageName = "milvus/engine:${DOCKER_VERSION}"
|
||||
def remoteImageName = "milvusdb/daily-build:${REMOTE_DOCKER_VERSION}"
|
||||
def localDockerRegistryImage = "${params.LOCAL_DOKCER_REGISTRY_URL}/${imageName}"
|
||||
def remoteDockerRegistryImage = "${params.REMOTE_DOKCER_REGISTRY_URL}/${remoteImageName}"
|
||||
try {
|
||||
deleteImages("${localDockerRegistryImage}", true)
|
||||
|
||||
def pullSourceImageStatus = sh(returnStatus: true, script: "docker pull ${localDockerRegistryImage}")
|
||||
|
||||
if (pullSourceImageStatus == 0) {
|
||||
def renameImageStatus = sh(returnStatus: true, script: "docker tag ${localDockerRegistryImage} ${remoteImageName} && docker rmi ${localDockerRegistryImage}")
|
||||
def sourceImage = docker.image("${remoteImageName}")
|
||||
docker.withRegistry("https://${params.REMOTE_DOKCER_REGISTRY_URL}", "${params.REMOTE_DOCKER_CREDENTIALS_ID}") {
|
||||
sourceImage.push()
|
||||
sourceImage.push("${REMOTE_DOCKER_LATEST_VERSION}")
|
||||
}
|
||||
} else {
|
||||
echo "\"${localDockerRegistryImage}\" image does not exist !"
|
||||
}
|
||||
} catch (exc) {
|
||||
throw exc
|
||||
} finally {
|
||||
deleteImages("${localDockerRegistryImage}", true)
|
||||
deleteImages("${remoteDockerRegistryImage}", true)
|
||||
}
|
||||
}
|
||||
|
||||
boolean deleteImages(String imageName, boolean force) {
|
||||
def imageNameStr = imageName.trim()
|
||||
def isExistImage = sh(returnStatus: true, script: "docker inspect --type=image ${imageNameStr} 2>&1 > /dev/null")
|
||||
if (isExistImage == 0) {
|
||||
def deleteImageStatus = 0
|
||||
if (force) {
|
||||
def imageID = sh(returnStdout: true, script: "docker inspect --type=image --format \"{{.ID}}\" ${imageNameStr}")
|
||||
deleteImageStatus = sh(returnStatus: true, script: "docker rmi -f ${imageID}")
|
||||
} else {
|
||||
deleteImageStatus = sh(returnStatus: true, script: "docker rmi ${imageNameStr}")
|
||||
}
|
||||
|
||||
if (deleteImageStatus != 0) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
148
tests/benchmark/ci/main_jenkinsfile
Normal file
148
tests/benchmark/ci/main_jenkinsfile
Normal file
@ -0,0 +1,148 @@
|
||||
pipeline {
|
||||
agent none
|
||||
|
||||
options {
|
||||
timestamps()
|
||||
}
|
||||
|
||||
parameters{
|
||||
string defaultValue: '0.11.1', description: 'server image version', name: 'IMAGE_VERSION', trim: true
|
||||
choice choices: ['single', 'shards'], description: 'server deploy mode', name: 'DEPLOY_MODE'
|
||||
string defaultValue: '011_data.json', description: 'test suite config yaml', name: 'CONFIG_FILE', trim: true
|
||||
string defaultValue: 'shards.json', description: 'shards test suite config yaml', name: 'SHARDS_CONFIG_FILE', trim: true
|
||||
string defaultValue: '09509e53-9125-4f5d-9ce8-42855987ad67', description: 'git credentials', name: 'GIT_USER', trim: true
|
||||
}
|
||||
|
||||
environment {
|
||||
HELM_URL = "https://github.com/milvus-io/milvus-helm.git"
|
||||
HELM_BRANCH = "0.11.1"
|
||||
TEST_URL = "git@192.168.1.105:Test/milvus_benchmark.git"
|
||||
TEST_BRANCH = "0.11.1"
|
||||
TEST_LIB_URL = "http://192.168.1.105:6060/Test/milvus_metrics.git"
|
||||
HELM_RELEASE_NAME = "milvus-benchmark-test-${env.BUILD_NUMBER}"
|
||||
HELM_SHARDS_RELEASE_NAME = "milvus-shards-benchmark-test-${env.BUILD_NUMBER}"
|
||||
}
|
||||
|
||||
stages {
|
||||
stage("Setup env") {
|
||||
agent {
|
||||
kubernetes {
|
||||
label "test-benchmark-${env.JOB_NAME}-${env.BUILD_NUMBER}"
|
||||
defaultContainer 'jnlp'
|
||||
yaml """
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: milvus
|
||||
componet: test
|
||||
spec:
|
||||
containers:
|
||||
- name: milvus-test-env
|
||||
image: registry.zilliz.com/milvus/milvus-test-env:v0.3
|
||||
command:
|
||||
- cat
|
||||
tty: true
|
||||
volumeMounts:
|
||||
- name: kubeconf
|
||||
mountPath: /root/.kube/
|
||||
readOnly: true
|
||||
- name: db-data-path
|
||||
mountPath: /test
|
||||
readOnly: false
|
||||
nodeSelector:
|
||||
kubernetes.io/hostname: idc-sh002
|
||||
tolerations:
|
||||
- key: worker
|
||||
operator: Equal
|
||||
value: performance
|
||||
effect: NoSchedule
|
||||
volumes:
|
||||
- name: kubeconf
|
||||
secret:
|
||||
secretName: test-cluster-config
|
||||
- name: db-data-path
|
||||
flexVolume:
|
||||
driver: "fstab/cifs"
|
||||
fsType: "cifs"
|
||||
secretRef:
|
||||
name: "cifs-test-secret"
|
||||
options:
|
||||
networkPath: "//172.16.70.249/test"
|
||||
mountOptions: "vers=1.0"
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage("Publish Daily Docker images") {
|
||||
steps {
|
||||
container('milvus-test-env') {
|
||||
script {
|
||||
boolean isNightlyTest = isTimeTriggeredBuild()
|
||||
if (isNightlyTest) {
|
||||
build job: 'milvus-publish-daily-docker', parameters: [string(name: 'LOCAL_DOKCER_REGISTRY_URL', value: 'registry.zilliz.com'), string(name: 'REMOTE_DOKCER_REGISTRY_URL', value: 'registry-1.docker.io'), string(name: 'REMOTE_DOCKER_CREDENTIALS_ID', value: 'milvus-docker-access-token'), string(name: 'BRANCH', value: String.valueOf(IMAGE_VERSION))], wait: false
|
||||
} else {
|
||||
echo "Skip publish daily docker images ..."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage("Deploy Test") {
|
||||
steps {
|
||||
container('milvus-test-env') {
|
||||
script {
|
||||
print "In Deploy Test Stage"
|
||||
if ("${params.DEPLOY_MODE}" == "single") {
|
||||
load "${env.WORKSPACE}/ci/jenkinsfile/deploy_test.groovy"
|
||||
} else {
|
||||
load "${env.WORKSPACE}/ci/jenkinsfile/deploy_shards_test.groovy"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage ("Cleanup Env") {
|
||||
steps {
|
||||
container('milvus-test-env') {
|
||||
script {
|
||||
if ("${params.DEPLOY_MODE}" == "single") {
|
||||
load "${env.WORKSPACE}/ci/jenkinsfile/cleanup.groovy"
|
||||
} else {
|
||||
load "${env.WORKSPACE}/ci/jenkinsfile/cleanupShards.groovy"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
post {
|
||||
success {
|
||||
script {
|
||||
echo "Milvus benchmark test success !"
|
||||
}
|
||||
}
|
||||
aborted {
|
||||
script {
|
||||
echo "Milvus benchmark test aborted !"
|
||||
}
|
||||
}
|
||||
failure {
|
||||
script {
|
||||
echo "Milvus benchmark test failed !"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
boolean isTimeTriggeredBuild() {
|
||||
if (currentBuild.getBuildCauses('hudson.triggers.TimerTrigger$TimerTriggerCause').size() != 0) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
13
tests/benchmark/ci/pod_containers/milvus-testframework.yaml
Normal file
13
tests/benchmark/ci/pod_containers/milvus-testframework.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: milvus
|
||||
componet: testframework
|
||||
spec:
|
||||
containers:
|
||||
- name: milvus-testframework
|
||||
image: registry.zilliz.com/milvus/milvus-test:v0.2
|
||||
command:
|
||||
- cat
|
||||
tty: true
|
104
tests/benchmark/ci/publish_jenkinsfile
Normal file
104
tests/benchmark/ci/publish_jenkinsfile
Normal file
@ -0,0 +1,104 @@
|
||||
pipeline {
|
||||
agent none
|
||||
|
||||
options {
|
||||
timestamps()
|
||||
}
|
||||
|
||||
parameters{
|
||||
string defaultValue: 'registry.zilliz.com', description: 'Local Docker registry URL', name: 'LOCAL_DOKCER_REGISTRY_URL', trim: true
|
||||
string defaultValue: 'registry-1.docker.io', description: 'Remote Docker registry URL', name: 'REMOTE_DOKCER_REGISTRY_URL', trim: true
|
||||
string defaultValue: 'milvus-docker-access-token', description: 'Remote Docker credentials id', name: 'REMOTE_DOCKER_CREDENTIALS_ID', trim: true
|
||||
string(defaultValue: "master", description: 'Milvus server version', name: 'BRANCH')
|
||||
}
|
||||
|
||||
environment {
|
||||
DAILY_BUILD_VERSION = VersionNumber([
|
||||
versionNumberString : '${BUILD_DATE_FORMATTED, "yyyyMMdd"}'
|
||||
]);
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('Push Daily Docker Images') {
|
||||
matrix {
|
||||
agent none
|
||||
axes {
|
||||
axis {
|
||||
name 'OS_NAME'
|
||||
values 'centos7'
|
||||
}
|
||||
|
||||
axis {
|
||||
name 'CPU_ARCH'
|
||||
values 'amd64'
|
||||
}
|
||||
|
||||
axis {
|
||||
name 'BINARY_VERSION'
|
||||
values 'gpu', 'cpu'
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage("Publish Docker Images") {
|
||||
environment {
|
||||
DOCKER_VERSION = "${params.BRANCH}-${BINARY_VERSION}-${OS_NAME}-release"
|
||||
REMOTE_DOCKER_VERSION = "${params.BRANCH}-${OS_NAME}-${BINARY_VERSION}-${DAILY_BUILD_VERSION}"
|
||||
REMOTE_DOCKER_LATEST_VERSION = "${params.BRANCH}-${OS_NAME}-${BINARY_VERSION}-latest"
|
||||
}
|
||||
|
||||
agent {
|
||||
kubernetes {
|
||||
label "${OS_NAME}-${BINARY_VERSION}-publish-${env.BUILD_NUMBER}"
|
||||
defaultContainer 'jnlp'
|
||||
yaml """
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: publish
|
||||
componet: docker
|
||||
spec:
|
||||
containers:
|
||||
- name: publish-images
|
||||
image: registry.zilliz.com/library/docker:v1.0.0
|
||||
securityContext:
|
||||
privileged: true
|
||||
command:
|
||||
- cat
|
||||
tty: true
|
||||
resources:
|
||||
limits:
|
||||
memory: "4Gi"
|
||||
cpu: "1.0"
|
||||
requests:
|
||||
memory: "2Gi"
|
||||
cpu: "0.5"
|
||||
volumeMounts:
|
||||
- name: docker-sock
|
||||
mountPath: /var/run/docker.sock
|
||||
volumes:
|
||||
- name: docker-sock
|
||||
hostPath:
|
||||
path: /var/run/docker.sock
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('Publish') {
|
||||
steps {
|
||||
container('publish-images') {
|
||||
script {
|
||||
load "${env.WORKSPACE}/ci/jenkinsfile/publishDailyImages.groovy"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
536
tests/benchmark/ci/scripts/yaml_processor.py
Executable file
536
tests/benchmark/ci/scripts/yaml_processor.py
Executable file
@ -0,0 +1,536 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
import os, shutil
|
||||
import getopt
|
||||
from ruamel.yaml import YAML, yaml_object
|
||||
from ruamel.yaml.comments import CommentedSeq, CommentedMap
|
||||
from ruamel.yaml.tokens import CommentToken
|
||||
|
||||
##
|
||||
yaml = YAML(typ="rt")
|
||||
## format yaml file
|
||||
yaml.indent(mapping=2, sequence=4, offset=2)
|
||||
|
||||
|
||||
############################################
|
||||
# Comment operation
|
||||
#
|
||||
############################################
|
||||
def _extract_comment(_comment):
|
||||
"""
|
||||
remove '#' at start of comment
|
||||
"""
|
||||
# if _comment is empty, do nothing
|
||||
if not _comment:
|
||||
return _comment
|
||||
|
||||
# str_ = _comment.lstrip(" ")
|
||||
str_ = _comment.strip()
|
||||
str_ = str_.lstrip("#")
|
||||
|
||||
return str_
|
||||
|
||||
|
||||
def _add_eol_comment(element, *args, **kwargs):
|
||||
"""
|
||||
add_eol_comment
|
||||
args --> (comment, key)
|
||||
"""
|
||||
if element is None or \
|
||||
(not isinstance(element, CommentedMap) and
|
||||
not isinstance(element, CommentedSeq)) or \
|
||||
args[0] is None or \
|
||||
len(args[0]) == 0:
|
||||
return
|
||||
|
||||
comment = args[0]
|
||||
# comment is empty, do nothing
|
||||
if not comment:
|
||||
return
|
||||
|
||||
key = args[1]
|
||||
try:
|
||||
element.yaml_add_eol_comment(*args, **kwargs)
|
||||
except Exception:
|
||||
element.ca.items.pop(key, None)
|
||||
element.yaml_add_eol_comment(*args, **kwargs)
|
||||
|
||||
|
||||
def _map_comment(_element, _key):
|
||||
origin_comment = ""
|
||||
token = _element.ca.items.get(_key, None)
|
||||
if token is not None:
|
||||
try:
|
||||
origin_comment = token[2].value
|
||||
except Exception:
|
||||
try:
|
||||
# comment is below element, add profix "#\n"
|
||||
col = _element.lc.col + 2
|
||||
space_list = [" " for i in range(col)]
|
||||
space_str = "".join(space_list)
|
||||
|
||||
origin_comment = "\n" + "".join([space_str + t.value for t in token[3]])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return origin_comment
|
||||
|
||||
|
||||
def _seq_comment(_element, _index):
|
||||
# get target comment
|
||||
_comment = ""
|
||||
token = _element.ca.items.get(_index, None)
|
||||
if token is not None:
|
||||
_comment = token[0].value
|
||||
|
||||
return _comment
|
||||
|
||||
|
||||
def _start_comment(_element):
|
||||
_comment = ""
|
||||
cmt = _element.ca.comment
|
||||
try:
|
||||
_comment = cmt[1][0].value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _comment
|
||||
|
||||
|
||||
def _comment_counter(_comment):
|
||||
"""
|
||||
|
||||
counter comment tips and split into list
|
||||
"""
|
||||
|
||||
x = lambda l: l.strip().strip("#").strip()
|
||||
|
||||
_counter = []
|
||||
if _comment.startswith("\n"):
|
||||
_counter.append("")
|
||||
_counter.append(x(_comment[1:]))
|
||||
|
||||
return _counter
|
||||
elif _comment.startswith("#\n"):
|
||||
_counter.append("")
|
||||
_counter.append(x(_comment[2:]))
|
||||
else:
|
||||
index = _comment.find("\n")
|
||||
_counter.append(x(_comment[:index]))
|
||||
_counter.append(x(_comment[index + 1:]))
|
||||
|
||||
return _counter
|
||||
|
||||
|
||||
def _obtain_comment(_m_comment, _t_comment):
|
||||
if not _m_comment or not _t_comment:
|
||||
return _m_comment or _t_comment
|
||||
|
||||
_m_counter = _comment_counter(_m_comment)
|
||||
_t_counter = _comment_counter(_t_comment)
|
||||
|
||||
if not _m_counter[0] and not _t_counter[1]:
|
||||
comment = _t_comment + _m_comment
|
||||
elif not _m_counter[1] and not _t_counter[0]:
|
||||
comment = _m_comment + _t_comment
|
||||
elif _t_counter[0] and _t_counter[1]:
|
||||
comment = _t_comment
|
||||
elif not _t_counter[0] and not _t_counter[1]:
|
||||
comment = _m_comment
|
||||
elif not _m_counter[0] and not _m_counter[1]:
|
||||
comment = _t_comment
|
||||
else:
|
||||
if _t_counter[0]:
|
||||
comment = _m_comment.replace(_m_counter[0], _t_counter[0], 1)
|
||||
else:
|
||||
comment = _m_comment.replace(_m_counter[1], _t_counter[1], 1)
|
||||
|
||||
i = comment.find("\n\n")
|
||||
while i >= 0:
|
||||
comment = comment.replace("\n\n\n", "\n\n", 1)
|
||||
i = comment.find("\n\n\n")
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
############################################
|
||||
# Utils
|
||||
#
|
||||
############################################
|
||||
def _get_update_par(_args):
|
||||
_dict = _args.__dict__
|
||||
|
||||
# file path
|
||||
_in_file = _dict.get("f", None) or _dict.get("file", None)
|
||||
# tips
|
||||
_tips = _dict.get('tips', None) or "Input \"-h\" for more information"
|
||||
# update
|
||||
_u = _dict.get("u", None) or _dict.get("update", None)
|
||||
# apppend
|
||||
_a = _dict.get('a', None) or _dict.get('append', None)
|
||||
# out stream group
|
||||
_i = _dict.get("i", None) or _dict.get("inplace", None)
|
||||
_o = _dict.get("o", None) or _dict.get("out_file", None)
|
||||
|
||||
return _in_file, _u, _a, _i, _o, _tips
|
||||
|
||||
|
||||
############################################
|
||||
# Element operation
|
||||
#
|
||||
############################################
|
||||
def update_map_element(element, key, value, comment, _type):
|
||||
"""
|
||||
element:
|
||||
key:
|
||||
value:
|
||||
comment:
|
||||
_type: value type.
|
||||
"""
|
||||
if element is None or not isinstance(element, CommentedMap):
|
||||
print("Only key-value update support")
|
||||
sys.exit(1)
|
||||
|
||||
origin_comment = _map_comment(element, key)
|
||||
|
||||
sub_element = element.get(key, None)
|
||||
if isinstance(sub_element, CommentedMap) or isinstance(sub_element, CommentedSeq):
|
||||
print("Only support update a single value")
|
||||
|
||||
element.update({key: value})
|
||||
|
||||
comment = _obtain_comment(origin_comment, comment)
|
||||
_add_eol_comment(element, _extract_comment(comment), key)
|
||||
|
||||
|
||||
def update_seq_element(element, value, comment, _type):
|
||||
if element is None or not isinstance(element, CommentedSeq):
|
||||
print("Param `-a` only use to append yaml list")
|
||||
sys.exit(1)
|
||||
element.append(str(value))
|
||||
|
||||
comment = _obtain_comment("", comment)
|
||||
_add_eol_comment(element, _extract_comment(comment), len(element) - 1)
|
||||
|
||||
|
||||
def run_update(code, keys, value, comment, _app):
|
||||
key_list = keys.split(".")
|
||||
|
||||
space_str = ":\n "
|
||||
key_str = "{}".format(key_list[0])
|
||||
for key in key_list[1:]:
|
||||
key_str = key_str + space_str + key
|
||||
space_str = space_str + " "
|
||||
if not _app:
|
||||
yaml_str = """{}: {}""".format(key_str, value)
|
||||
else:
|
||||
yaml_str = "{}{}- {}".format(key_str, space_str, value)
|
||||
|
||||
if comment:
|
||||
yaml_str = "{} # {}".format(yaml_str, comment)
|
||||
|
||||
mcode = yaml.load(yaml_str)
|
||||
|
||||
_merge(code, mcode)
|
||||
|
||||
|
||||
def _update(code, _update, _app, _tips):
|
||||
if not _update:
|
||||
return code
|
||||
|
||||
_update_list = [l.strip() for l in _update.split(",")]
|
||||
for l in _update_list:
|
||||
try:
|
||||
variant, comment = l.split("#")
|
||||
except ValueError:
|
||||
variant = l
|
||||
comment = None
|
||||
|
||||
try:
|
||||
keys, value = variant.split("=")
|
||||
run_update(code, keys, value, comment, _app)
|
||||
except ValueError:
|
||||
print("Invalid format. print command \"--help\" get more info.")
|
||||
sys.exit(1)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def _backup(in_file_p):
|
||||
backup_p = in_file_p + ".bak"
|
||||
|
||||
if os.path.exists(backup_p):
|
||||
os.remove(backup_p)
|
||||
|
||||
if not os.path.exists(in_file_p):
|
||||
print("File {} not exists.".format(in_file_p))
|
||||
sys.exit(1)
|
||||
|
||||
shutil.copyfile(in_file_p, backup_p) # 复制文件
|
||||
|
||||
|
||||
def _recovery(in_file_p):
|
||||
backup_p = in_file_p + ".bak"
|
||||
|
||||
if not os.path.exists(in_file_p):
|
||||
print("File {} not exists.".format(in_file_p))
|
||||
sys.exit(1)
|
||||
elif not os.path.exists(backup_p):
|
||||
print("Backup file not exists")
|
||||
sys.exit(0)
|
||||
|
||||
os.remove(in_file_p)
|
||||
|
||||
os.rename(backup_p, in_file_p)
|
||||
|
||||
|
||||
# master merge target
|
||||
def _merge(master, target):
|
||||
if type(master) != type(target):
|
||||
print("yaml format not match:\n")
|
||||
yaml.dump(master, sys.stdout)
|
||||
print("\n&&\n")
|
||||
yaml.dump(target, sys.stdout)
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
## item is a sequence
|
||||
if isinstance(target, CommentedSeq):
|
||||
for index in range(len(target)):
|
||||
# get target comment
|
||||
target_comment = _seq_comment(target, index)
|
||||
|
||||
master_index = len(master)
|
||||
|
||||
target_item = target[index]
|
||||
|
||||
if isinstance(target_item, CommentedMap):
|
||||
merge_flag = False
|
||||
for idx in range(len(master)):
|
||||
if isinstance(master[idx], CommentedMap):
|
||||
if master[idx].keys() == target_item.keys():
|
||||
_merge(master[idx], target_item)
|
||||
# nonlocal merge_flag
|
||||
master_index = idx
|
||||
merge_flag = True
|
||||
break
|
||||
|
||||
if merge_flag is False:
|
||||
master.append(target_item)
|
||||
elif target_item not in master:
|
||||
master.append(target[index])
|
||||
else:
|
||||
# merge(master[index], target[index])
|
||||
pass
|
||||
|
||||
# # remove enter signal in previous item
|
||||
previous_comment = _seq_comment(master, master_index - 1)
|
||||
_add_eol_comment(master, _extract_comment(previous_comment), master_index - 1)
|
||||
|
||||
origin_comment = _seq_comment(master, master_index)
|
||||
comment = _obtain_comment(origin_comment, target_comment)
|
||||
if len(comment) > 0:
|
||||
_add_eol_comment(master, _extract_comment(comment) + "\n\n", len(master) - 1)
|
||||
|
||||
## item is a map
|
||||
elif isinstance(target, CommentedMap):
|
||||
for item in target:
|
||||
if item == "flag":
|
||||
print("")
|
||||
origin_comment = _map_comment(master, item)
|
||||
target_comment = _map_comment(target, item)
|
||||
|
||||
# get origin start comment
|
||||
origin_start_comment = _start_comment(master)
|
||||
|
||||
# get target start comment
|
||||
target_start_comment = _start_comment(target)
|
||||
|
||||
m = master.get(item, default=None)
|
||||
if m is None or \
|
||||
(not (isinstance(m, CommentedMap) or
|
||||
isinstance(m, CommentedSeq))):
|
||||
master.update({item: target[item]})
|
||||
|
||||
else:
|
||||
_merge(master[item], target[item])
|
||||
|
||||
comment = _obtain_comment(origin_comment, target_comment)
|
||||
if len(comment) > 0:
|
||||
_add_eol_comment(master, _extract_comment(comment), item)
|
||||
|
||||
start_comment = _obtain_comment(origin_start_comment, target_start_comment)
|
||||
if len(start_comment) > 0:
|
||||
master.yaml_set_start_comment(_extract_comment(start_comment))
|
||||
|
||||
|
||||
def _save(_code, _file):
|
||||
with open(_file, 'w') as wf:
|
||||
yaml.dump(_code, wf)
|
||||
|
||||
|
||||
def _load(_file):
|
||||
with open(_file, 'r') as rf:
|
||||
code = yaml.load(rf)
|
||||
return code
|
||||
|
||||
|
||||
############################################
|
||||
# sub parser process operation
|
||||
#
|
||||
############################################
|
||||
def merge_yaml(_args):
|
||||
_dict = _args.__dict__
|
||||
|
||||
_m_file = _dict.get("merge_file", None)
|
||||
_in_file, _u, _a, _i, _o, _tips = _get_update_par(_args)
|
||||
|
||||
if not (_in_file and _m_file):
|
||||
print(_tips)
|
||||
sys.exit(1)
|
||||
|
||||
code = _load(_in_file)
|
||||
mcode = _load(_m_file)
|
||||
|
||||
_merge(code, mcode)
|
||||
|
||||
_update(code, _u, _a, _tips)
|
||||
|
||||
if _i:
|
||||
_backup(_in_file)
|
||||
_save(code, _in_file)
|
||||
elif _o:
|
||||
_save(code, _o)
|
||||
else:
|
||||
print(_tips)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def update_yaml(_args):
|
||||
_in_file, _u, _a, _i, _o, _tips = _get_update_par(_args)
|
||||
|
||||
if not _in_file or not _u:
|
||||
print(_tips)
|
||||
sys.exit(1)
|
||||
|
||||
code = _load(_in_file)
|
||||
|
||||
if _i and _o:
|
||||
print(_tips)
|
||||
sys.exit(1)
|
||||
|
||||
_update(code, _u, _a, _tips)
|
||||
|
||||
if _i:
|
||||
_backup(_in_file)
|
||||
_save(code, _in_file)
|
||||
elif _o:
|
||||
_save(code, _o)
|
||||
|
||||
|
||||
def reset(_args):
|
||||
_dict = _args.__dict__
|
||||
_f = _dict.get('f', None) or _dict.get('file', None)
|
||||
|
||||
if _f:
|
||||
_recovery(_f)
|
||||
else:
|
||||
_t = _dict.get('tips', None) or "Input \"-h\" for more information"
|
||||
print(_t)
|
||||
|
||||
|
||||
############################################
|
||||
# Cli operation
|
||||
#
|
||||
############################################
|
||||
def _set_merge_parser(_parsers):
|
||||
"""
|
||||
config merge parser
|
||||
"""
|
||||
|
||||
merge_parser = _parsers.add_parser("merge", help="merge with another yaml file")
|
||||
|
||||
_set_merge_parser_arg(merge_parser)
|
||||
_set_update_parser_arg(merge_parser)
|
||||
|
||||
merge_parser.set_defaults(
|
||||
function=merge_yaml,
|
||||
tips=merge_parser.format_help()
|
||||
)
|
||||
|
||||
|
||||
def _set_merge_parser_arg(_parser):
|
||||
"""
|
||||
config parser argument for merging
|
||||
"""
|
||||
|
||||
_parser.add_argument("-m", "--merge-file", help="indicate merge yaml file")
|
||||
|
||||
|
||||
def _set_update_parser(_parsers):
|
||||
"""
|
||||
config merge parser
|
||||
"""
|
||||
|
||||
update_parser = _parsers.add_parser("update", help="update with another yaml file")
|
||||
_set_update_parser_arg(update_parser)
|
||||
|
||||
update_parser.set_defaults(
|
||||
function=update_yaml,
|
||||
tips=update_parser.format_help()
|
||||
)
|
||||
|
||||
|
||||
def _set_update_parser_arg(_parser):
|
||||
"""
|
||||
config parser argument for updating
|
||||
"""
|
||||
|
||||
_parser.add_argument("-f", "--file", help="source yaml file")
|
||||
_parser.add_argument('-u', '--update', help="update with args, instance as \"a.b.c=d# d comment\"")
|
||||
_parser.add_argument('-a', '--append', action="store_true", help="append to a seq")
|
||||
|
||||
group = _parser.add_mutually_exclusive_group()
|
||||
group.add_argument("-o", "--out-file", help="indicate output yaml file")
|
||||
group.add_argument("-i", "--inplace", action="store_true", help="indicate whether result store in origin file")
|
||||
|
||||
|
||||
def _set_reset_parser(_parsers):
|
||||
"""
|
||||
config merge parser
|
||||
"""
|
||||
|
||||
reset_parser = _parsers.add_parser("reset", help="reset yaml file")
|
||||
|
||||
# indicate yaml file
|
||||
reset_parser.add_argument('-f', '--file', help="indicate input yaml file")
|
||||
|
||||
reset_parser.set_defaults(
|
||||
function=reset,
|
||||
tips=reset_parser.format_help()
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
sub_parsers = parser.add_subparsers()
|
||||
|
||||
# set merge command
|
||||
_set_merge_parser(sub_parsers)
|
||||
|
||||
# set update command
|
||||
_set_update_parser(sub_parsers)
|
||||
|
||||
# set reset command
|
||||
_set_reset_parser(sub_parsers)
|
||||
|
||||
# parse argument and run func
|
||||
args = parser.parse_args()
|
||||
args.function(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
460
tests/benchmark/client.py
Normal file
460
tests/benchmark/client.py
Normal file
@ -0,0 +1,460 @@
|
||||
import sys
|
||||
import pdb
|
||||
import random
|
||||
import logging
|
||||
import json
|
||||
import time, datetime
|
||||
import traceback
|
||||
from multiprocessing import Process
|
||||
from milvus import Milvus, DataType
|
||||
import numpy as np
|
||||
import utils
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.client")
|
||||
|
||||
SERVER_HOST_DEFAULT = "127.0.0.1"
|
||||
SERVER_PORT_DEFAULT = 19530
|
||||
INDEX_MAP = {
|
||||
"flat": "FLAT",
|
||||
"ivf_flat": "IVF_FLAT",
|
||||
"ivf_sq8": "IVF_SQ8",
|
||||
"nsg": "NSG",
|
||||
"ivf_sq8h": "IVF_SQ8_HYBRID",
|
||||
"ivf_pq": "IVF_PQ",
|
||||
"hnsw": "HNSW",
|
||||
"annoy": "ANNOY",
|
||||
"bin_flat": "BIN_FLAT",
|
||||
"bin_ivf_flat": "BIN_IVF_FLAT",
|
||||
"rhnsw_pq": "RHNSW_PQ",
|
||||
"rhnsw_sq": "RHNSW_SQ"
|
||||
}
|
||||
epsilon = 0.1
|
||||
|
||||
|
||||
def time_wrapper(func):
|
||||
"""
|
||||
This decorator prints the execution time for the decorated function.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
start = time.time()
|
||||
# logger.debug("Milvus {} start".format(func.__name__))
|
||||
log = kwargs.get("log", True)
|
||||
kwargs.pop("log", None)
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
if log:
|
||||
logger.debug("Milvus {} run in {}s".format(func.__name__, round(end - start, 2)))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MilvusClient(object):
|
||||
def __init__(self, collection_name=None, host=None, port=None, timeout=180):
|
||||
self._collection_name = collection_name
|
||||
start_time = time.time()
|
||||
if not host:
|
||||
host = SERVER_HOST_DEFAULT
|
||||
if not port:
|
||||
port = SERVER_PORT_DEFAULT
|
||||
logger.debug(host)
|
||||
logger.debug(port)
|
||||
# retry connect remote server
|
||||
i = 0
|
||||
while time.time() < start_time + timeout:
|
||||
try:
|
||||
self._milvus = Milvus(
|
||||
host=host,
|
||||
port=port,
|
||||
try_connect=False,
|
||||
pre_ping=False)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
logger.error("Milvus connect failed: %d times" % i)
|
||||
i = i + 1
|
||||
time.sleep(i)
|
||||
|
||||
if time.time() > start_time + timeout:
|
||||
raise Exception("Server connect timeout")
|
||||
# self._metric_type = None
|
||||
|
||||
def __str__(self):
|
||||
return 'Milvus collection %s' % self._collection_name
|
||||
|
||||
def check_status(self, status):
|
||||
if not status.OK():
|
||||
logger.error(status.message)
|
||||
logger.error(self._milvus.server_status())
|
||||
logger.error(self.count())
|
||||
raise Exception("Status not ok")
|
||||
|
||||
def check_result_ids(self, result):
|
||||
for index, item in enumerate(result):
|
||||
if item[0].distance >= epsilon:
|
||||
logger.error(index)
|
||||
logger.error(item[0].distance)
|
||||
raise Exception("Distance wrong")
|
||||
|
||||
# only support the given field name
|
||||
def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False,
|
||||
collection_name=None, other_fields=None):
|
||||
self._dimension = dimension
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
vec_field_name = utils.get_default_field_name(data_type)
|
||||
fields = [{"name": vec_field_name, "type": data_type, "params": {"dim": dimension}}]
|
||||
if other_fields:
|
||||
other_fields = other_fields.split(",")
|
||||
if "int" in other_fields:
|
||||
fields.append({"name": utils.DEFAULT_INT_FIELD_NAME, "type": DataType.INT64})
|
||||
if "float" in other_fields:
|
||||
fields.append({"name": utils.DEFAULT_FLOAT_FIELD_NAME, "type": DataType.FLOAT})
|
||||
create_param = {
|
||||
"fields": fields,
|
||||
"auto_id": auto_id}
|
||||
try:
|
||||
self._milvus.create_collection(collection_name, create_param)
|
||||
logger.info("Create collection: <%s> successfully" % collection_name)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise
|
||||
|
||||
def create_partition(self, tag, collection_name=None):
|
||||
if not collection_name:
|
||||
collection_name = self._collection_name
|
||||
self._milvus.create_partition(collection_name, tag)
|
||||
|
||||
def generate_values(self, data_type, vectors, ids):
|
||||
values = None
|
||||
if data_type in [DataType.INT32, DataType.INT64]:
|
||||
values = ids
|
||||
elif data_type in [DataType.FLOAT, DataType.DOUBLE]:
|
||||
values = [(i + 0.0) for i in ids]
|
||||
elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
|
||||
values = vectors
|
||||
return values
|
||||
|
||||
def generate_entities(self, vectors, ids=None, collection_name=None):
|
||||
entities = []
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
info = self.get_info(collection_name)
|
||||
for field in info["fields"]:
|
||||
field_type = field["type"]
|
||||
entities.append(
|
||||
{"name": field["name"], "type": field_type, "values": self.generate_values(field_type, vectors, ids)})
|
||||
return entities
|
||||
|
||||
@time_wrapper
|
||||
def insert(self, entities, ids=None, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
try:
|
||||
insert_ids = self._milvus.insert(tmp_collection_name, entities, ids=ids)
|
||||
return insert_ids
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def get_dimension(self):
|
||||
info = self.get_info()
|
||||
for field in info["fields"]:
|
||||
if field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
|
||||
return field["params"]["dim"]
|
||||
|
||||
def get_rand_ids(self, length):
|
||||
segment_ids = []
|
||||
while True:
|
||||
stats = self.get_stats()
|
||||
segments = stats["partitions"][0]["segments"]
|
||||
# random choice one segment
|
||||
segment = random.choice(segments)
|
||||
try:
|
||||
segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["id"])
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
if not len(segment_ids):
|
||||
continue
|
||||
elif len(segment_ids) > length:
|
||||
return random.sample(segment_ids, length)
|
||||
else:
|
||||
logger.debug("Reset length: %d" % len(segment_ids))
|
||||
return segment_ids
|
||||
|
||||
# def get_rand_ids_each_segment(self, length):
|
||||
# res = []
|
||||
# status, stats = self._milvus.get_collection_stats(self._collection_name)
|
||||
# self.check_status(status)
|
||||
# segments = stats["partitions"][0]["segments"]
|
||||
# segments_num = len(segments)
|
||||
# # random choice from each segment
|
||||
# for segment in segments:
|
||||
# status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
|
||||
# self.check_status(status)
|
||||
# res.extend(segment_ids[:length])
|
||||
# return segments_num, res
|
||||
|
||||
# def get_rand_entities(self, length):
|
||||
# ids = self.get_rand_ids(length)
|
||||
# status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
|
||||
# self.check_status(status)
|
||||
# return ids, get_res
|
||||
|
||||
def get(self):
|
||||
get_ids = random.randint(1, 1000000)
|
||||
self._milvus.get_entity_by_id(self._collection_name, [get_ids])
|
||||
|
||||
@time_wrapper
|
||||
def get_entities(self, get_ids):
|
||||
get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
|
||||
return get_res
|
||||
|
||||
@time_wrapper
|
||||
def delete(self, ids, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
self._milvus.delete_entity_by_id(tmp_collection_name, ids)
|
||||
|
||||
def delete_rand(self):
|
||||
delete_id_length = random.randint(1, 100)
|
||||
count_before = self.count()
|
||||
logger.debug("%s: length to delete: %d" % (self._collection_name, delete_id_length))
|
||||
delete_ids = self.get_rand_ids(delete_id_length)
|
||||
self.delete(delete_ids)
|
||||
self.flush()
|
||||
logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
|
||||
get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
|
||||
for item in get_res:
|
||||
assert not item
|
||||
# if count_before - len(delete_ids) < self.count():
|
||||
# logger.error(delete_ids)
|
||||
# raise Exception("Error occured")
|
||||
|
||||
@time_wrapper
|
||||
def flush(self,_async=False, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
self._milvus.flush([tmp_collection_name], _async=_async)
|
||||
|
||||
@time_wrapper
|
||||
def compact(self, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
status = self._milvus.compact(tmp_collection_name)
|
||||
self.check_status(status)
|
||||
|
||||
@time_wrapper
|
||||
def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None):
|
||||
index_type = INDEX_MAP[index_type]
|
||||
metric_type = utils.metric_type_trans(metric_type)
|
||||
logger.info("Building index start, collection_name: %s, index_type: %s, metric_type: %s" % (
|
||||
self._collection_name, index_type, metric_type))
|
||||
if index_param:
|
||||
logger.info(index_param)
|
||||
index_params = {
|
||||
"index_type": index_type,
|
||||
"metric_type": metric_type,
|
||||
"params": index_param
|
||||
}
|
||||
self._milvus.create_index(self._collection_name, field_name, index_params, _async=_async)
|
||||
|
||||
# TODO: need to check
|
||||
def describe_index(self, field_name):
|
||||
# stats = self.get_stats()
|
||||
info = self._milvus.describe_index(self._collection_name, field_name)
|
||||
index_info = {"index_type": "flat", "index_param": None}
|
||||
for field in info["fields"]:
|
||||
for index in field['indexes']:
|
||||
if not index or "index_type" not in index:
|
||||
continue
|
||||
else:
|
||||
for k, v in INDEX_MAP.items():
|
||||
if index['index_type'] == v:
|
||||
index_info['index_type'] = k
|
||||
index_info['index_param'] = index['params']
|
||||
return index_info
|
||||
return index_info
|
||||
|
||||
def drop_index(self, field_name):
|
||||
logger.info("Drop index: %s" % self._collection_name)
|
||||
return self._milvus.drop_index(self._collection_name, field_name)
|
||||
|
||||
@time_wrapper
|
||||
def query(self, vector_query, filter_query=None, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
must_params = [vector_query]
|
||||
if filter_query:
|
||||
must_params.extend(filter_query)
|
||||
query = {
|
||||
"bool": {"must": must_params}
|
||||
}
|
||||
result = self._milvus.search(tmp_collection_name, query)
|
||||
return result
|
||||
|
||||
@time_wrapper
|
||||
def load_and_query(self, vector_query, filter_query=None, collection_name=None):
|
||||
tmp_collection_name = self._collection_name if collection_name is None else collection_name
|
||||
must_params = [vector_query]
|
||||
if filter_query:
|
||||
must_params.extend(filter_query)
|
||||
query = {
|
||||
"bool": {"must": must_params}
|
||||
}
|
||||
self.load_collection(tmp_collection_name)
|
||||
result = self._milvus.search(tmp_collection_name, query)
|
||||
return result
|
||||
|
||||
def get_ids(self, result):
|
||||
idss = result._entities.ids
|
||||
ids = []
|
||||
len_idss = len(idss)
|
||||
len_r = len(result)
|
||||
top_k = len_idss // len_r
|
||||
for offset in range(0, len_idss, top_k):
|
||||
ids.append(idss[offset: min(offset + top_k, len_idss)])
|
||||
return ids
|
||||
|
||||
def query_rand(self, nq_max=100):
|
||||
# for ivf search
|
||||
dimension = 128
|
||||
top_k = random.randint(1, 100)
|
||||
nq = random.randint(1, nq_max)
|
||||
nprobe = random.randint(1, 100)
|
||||
search_param = {"nprobe": nprobe}
|
||||
query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
|
||||
metric_type = random.choice(["l2", "ip"])
|
||||
logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
|
||||
vec_field_name = utils.get_default_field_name()
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors,
|
||||
"metric_type": utils.metric_type_trans(metric_type),
|
||||
"params": search_param}
|
||||
}}
|
||||
self.query(vector_query)
|
||||
|
||||
def load_query_rand(self, nq_max=100):
|
||||
# for ivf search
|
||||
dimension = 128
|
||||
top_k = random.randint(1, 100)
|
||||
nq = random.randint(1, nq_max)
|
||||
nprobe = random.randint(1, 100)
|
||||
search_param = {"nprobe": nprobe}
|
||||
query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
|
||||
metric_type = random.choice(["l2", "ip"])
|
||||
logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
|
||||
vec_field_name = utils.get_default_field_name()
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors,
|
||||
"metric_type": utils.metric_type_trans(metric_type),
|
||||
"params": search_param}
|
||||
}}
|
||||
self.load_and_query(vector_query)
|
||||
|
||||
# TODO: need to check
|
||||
def count(self, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
row_count = self._milvus.get_collection_stats(collection_name)["row_count"]
|
||||
logger.debug("Row count: %d in collection: <%s>" % (row_count, collection_name))
|
||||
return row_count
|
||||
|
||||
def drop(self, timeout=120, collection_name=None):
|
||||
timeout = int(timeout)
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
logger.info("Start delete collection: %s" % collection_name)
|
||||
self._milvus.drop_collection(collection_name)
|
||||
i = 0
|
||||
while i < timeout:
|
||||
try:
|
||||
row_count = self.count(collection_name=collection_name)
|
||||
if row_count:
|
||||
time.sleep(1)
|
||||
i = i + 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(str(e))
|
||||
break
|
||||
if i >= timeout:
|
||||
logger.error("Delete collection timeout")
|
||||
|
||||
def get_stats(self):
|
||||
return self._milvus.get_collection_stats(self._collection_name)
|
||||
|
||||
def get_info(self, collection_name=None):
|
||||
# pdb.set_trace()
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
return self._milvus.get_collection_info(collection_name)
|
||||
|
||||
def show_collections(self):
|
||||
return self._milvus.list_collections()
|
||||
|
||||
def exists_collection(self, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
res = self._milvus.has_collection(collection_name)
|
||||
return res
|
||||
|
||||
def clean_db(self):
|
||||
collection_names = self.show_collections()
|
||||
for name in collection_names:
|
||||
self.drop(collection_name=name)
|
||||
|
||||
@time_wrapper
|
||||
def load_collection(self, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
return self._milvus.load_collection(collection_name, timeout=3000)
|
||||
|
||||
@time_wrapper
|
||||
def release_collection(self, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
return self._milvus.release_collection(collection_name, timeout=3000)
|
||||
|
||||
@time_wrapper
|
||||
def load_partitions(self, tag_names, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
return self._milvus.load_partitions(collection_name, tag_names, timeout=3000)
|
||||
|
||||
@time_wrapper
|
||||
def release_partitions(self, tag_names, collection_name=None):
|
||||
if collection_name is None:
|
||||
collection_name = self._collection_name
|
||||
return self._milvus.release_partitions(collection_name, tag_names, timeout=3000)
|
||||
|
||||
# TODO: remove
|
||||
# def get_server_version(self):
|
||||
# return self._milvus.server_version()
|
||||
|
||||
# def get_server_mode(self):
|
||||
# return self.cmd("mode")
|
||||
|
||||
# def get_server_commit(self):
|
||||
# return self.cmd("build_commit_id")
|
||||
|
||||
# def get_server_config(self):
|
||||
# return json.loads(self.cmd("get_milvus_config"))
|
||||
|
||||
# def get_mem_info(self):
|
||||
# result = json.loads(self.cmd("get_system_info"))
|
||||
# result_human = {
|
||||
# # unit: Gb
|
||||
# "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
|
||||
# }
|
||||
# return result_human
|
||||
|
||||
# def cmd(self, command):
|
||||
# res = self._milvus._cmd(command)
|
||||
# logger.info("Server command: %s, result: %s" % (command, res))
|
||||
# return res
|
||||
|
||||
# @time_wrapper
|
||||
# def set_config(self, parent_key, child_key, value):
|
||||
# self._milvus.set_config(parent_key, child_key, value)
|
||||
|
||||
# def get_config(self, key):
|
||||
# return self._milvus.get_config(key)
|
366
tests/benchmark/docker_runner.py
Normal file
366
tests/benchmark/docker_runner.py
Normal file
@ -0,0 +1,366 @@
|
||||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
import parser
|
||||
from runner import Runner
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.docker")
|
||||
|
||||
|
||||
class DockerRunner(Runner):
    """run docker mode

    Drives benchmark scenarios against a Milvus server that is started as a
    Docker container per test definition and removed afterwards.
    NOTE(review): reconstructed from a whitespace-mangled diff; block nesting
    marked below as "reconstructed" should be confirmed against the original file.
    """
    def __init__(self, image):
        # image: Docker image (repo:tag) used to launch the Milvus server.
        super(DockerRunner, self).__init__()
        self.image = image

    def run(self, definition, run_type=None):
        """Execute one test `definition` according to `run_type`.

        definition: dict mapping an operation type ("insert"/"query") to a dict
                    with "run_count" and "params" (a list of per-run param dicts).
        run_type:   one of "performance", "insert_performance",
                    "search_performance", "accuracy", "stability";
                    anything else is logged and ignored.
        Side effects: starts/removes Docker containers, rewrites server config,
        creates/drops collections, writes debug files, sleeps between phases.
        """
        if run_type == "performance":
            for op_type, op_value in definition.items():
                # run docker mode
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None

                if op_type == "insert":
                    if not run_params:
                        logger.debug("No run params")
                        continue
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        collection_name = param["collection_name"]
                        volume_name = param["db_path_prefix"]
                        print(collection_name)
                        # Collection name encodes its own metadata; decode it.
                        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
                        for k, v in param.items():
                            if k.startswith("server."):
                                # Update server config
                                utils.modify_config(k, v, type="server", db_slave=None)
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(collection_name)
                        # Check has collection or not; recreate from scratch if so.
                        if milvus.exists_collection():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_collection(collection_name, dimension, index_file_size, metric_type)
                        # debug
                        # milvus.create_index("ivf_sq8", 16384)
                        res = self.do_insert(milvus, collection_name, data_type, dimension, collection_size, param["ni_per"])
                        logger.info(res)
                        # wait for file merge (heuristic: proportional to data volume)
                        time.sleep(collection_size * dimension / 5000000)
                        # Clear up
                        utils.remove_container(container)

                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        collection_name = param["dataset"]
                        volume_name = param["db_path_prefix"]
                        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
                        for k, v in param.items():
                            if k.startswith("server."):
                                utils.modify_config(k, v, type="server")
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(collection_name)
                        logger.debug(milvus.show_collections())
                        # Check has collection or not
                        if not milvus.exists_collection():
                            logger.warning("Table %s not existed, continue exec next params ..." % collection_name)
                            continue
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)
                        for index_type in index_types:
                            for nlist in nlists:
                                result = milvus.describe_index()
                                logger.info(result)
                                # milvus.drop_index()
                                # milvus.create_index(index_type, nlist)
                                result = milvus.describe_index()
                                logger.info(result)
                                logger.info(milvus.count())
                                # preload index
                                milvus.preload_collection()
                                logger.info("Start warm up query")
                                res = self.do_query(milvus, collection_name, [1], [1], 1, 1)
                                logger.info("End warm up query")
                                # Run query test
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, collection_name, top_ks, nqs, nprobe, run_count)
                                    headers = ["Nq/Top-k"]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_collection(headers, nqs, res)
                        utils.remove_container(container)

        elif run_type == "insert_performance":
            # Same flow as the "insert" branch of "performance" above,
            # applied regardless of op_type.
            for op_type, op_value in definition.items():
                # run docker mode
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                if not run_params:
                    logger.debug("No run params")
                    continue
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    collection_name = param["collection_name"]
                    volume_name = param["db_path_prefix"]
                    print(collection_name)
                    (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
                    for k, v in param.items():
                        if k.startswith("server."):
                            # Update server config
                            utils.modify_config(k, v, type="server", db_slave=None)
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(collection_name)
                    # Check has collection or not
                    if milvus.exists_collection():
                        milvus.delete()
                        time.sleep(10)
                    milvus.create_collection(collection_name, dimension, index_file_size, metric_type)
                    # debug
                    # milvus.create_index("ivf_sq8", 16384)
                    res = self.do_insert(milvus, collection_name, data_type, dimension, collection_size, param["ni_per"])
                    logger.info(res)
                    # wait for file merge
                    time.sleep(collection_size * dimension / 5000000)
                    # Clear up
                    utils.remove_container(container)

        elif run_type == "search_performance":
            # Same flow as the "query" branch of "performance" above,
            # applied regardless of op_type.
            for op_type, op_value in definition.items():
                # run docker mode
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    collection_name = param["dataset"]
                    volume_name = param["db_path_prefix"]
                    (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(collection_name)
                    logger.debug(milvus.show_collections())
                    # Check has collection or not
                    if not milvus.exists_collection():
                        logger.warning("Table %s not existed, continue exec next params ..." % collection_name)
                        continue
                    # parse index info
                    index_types = param["index.index_types"]
                    nlists = param["index.nlists"]
                    # parse top-k, nq, nprobe
                    top_ks, nqs, nprobes = parser.search_params_parser(param)
                    for index_type in index_types:
                        for nlist in nlists:
                            result = milvus.describe_index()
                            logger.info(result)
                            # milvus.drop_index()
                            # milvus.create_index(index_type, nlist)
                            result = milvus.describe_index()
                            logger.info(result)
                            logger.info(milvus.count())
                            # preload index
                            milvus.preload_collection()
                            logger.info("Start warm up query")
                            res = self.do_query(milvus, collection_name, [1], [1], 1, 1)
                            logger.info("End warm up query")
                            # Run query test
                            for nprobe in nprobes:
                                logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                res = self.do_query(milvus, collection_name, top_ks, nqs, nprobe, run_count)
                                headers = ["Nq/Top-k"]
                                headers.extend([str(top_k) for top_k in top_ks])
                                utils.print_collection(headers, nqs, res)
                    utils.remove_container(container)

        elif run_type == "accuracy":
            """
            Example params dict for this run type:
            {
                "dataset": "random_50m_1024_512",
                "index.index_types": ["flat", ivf_flat", "ivf_sq8"],
                "index.nlists": [16384],
                "nprobes": [1, 32, 128],
                "nqs": [100],
                "top_ks": [1, 64],
                "server.use_blas_threshold": 1100,
                "server.cpu_cache_capacity": 256
            }
            """
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None

                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    collection_name = param["dataset"]
                    # sift_acc: compare against precomputed SIFT ground-truth ids
                    # instead of a FLAT baseline run.
                    sift_acc = False
                    if "sift_acc" in param:
                        sift_acc = param["sift_acc"]
                    (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")
                    volume_name = param["db_path_prefix"]
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(collection_name)
                    # Check has collection or not
                    if not milvus.exists_collection():
                        logger.warning("Table %s not existed, continue exec next params ..." % collection_name)
                        continue

                    # parse index info
                    index_types = param["index.index_types"]
                    nlists = param["index.nlists"]
                    # parse top-k, nq, nprobe
                    top_ks, nqs, nprobes = parser.search_params_parser(param)
                    if sift_acc is True:
                        # preload groundtruth data
                        true_ids_all = self.get_groundtruth_ids(collection_size)
                    acc_dict = {}
                    for index_type in index_types:
                        for nlist in nlists:
                            result = milvus.describe_index()
                            logger.info(result)
                            milvus.create_index(index_type, nlist)
                            # preload index
                            milvus.preload_collection()
                            # Run query test
                            for nprobe in nprobes:
                                logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                for top_k in top_ks:
                                    for nq in nqs:
                                        result_ids = []
                                        id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                    (collection_name, index_type, nlist, metric_type, nprobe, top_k, nq)
                                        if sift_acc is False:
                                            # Baseline mode: persist result ids, then compare
                                            # non-flat indexes against the flat run's ids.
                                            self.do_query_acc(milvus, collection_name, top_k, nq, nprobe, id_prefix)
                                            if index_type != "flat":
                                                # Compute accuracy
                                                base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                    (collection_name, nlist, metric_type, nprobe, top_k, nq)
                                                avg_acc = self.compute_accuracy(base_name, id_prefix)
                                                logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc))
                                        else:
                                            # Ground-truth mode: recall vs precomputed SIFT ids,
                                            # with a per-query debug dump.
                                            result_ids, result_distances = self.do_query_ids(milvus, collection_name, top_k, nq, nprobe)
                                            debug_file_ids = "0.5.3_result_ids"
                                            debug_file_distances = "0.5.3_result_distances"
                                            with open(debug_file_ids, "w+") as fd:
                                                total = 0
                                                for index, item in enumerate(result_ids):
                                                    true_item = true_ids_all[:nq, :top_k].tolist()[index]
                                                    tmp = set(item).intersection(set(true_item))
                                                    total = total + len(tmp)
                                                    fd.write("query: N-%d, intersection: %d, total: %d\n" % (index, len(tmp), total))
                                                    fd.write("%s\n" % str(item))
                                                    fd.write("%s\n" % str(true_item))
                                            acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
                                            logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value))
                    # # print accuracy collection
                    # headers = [collection_name]
                    # headers.extend([str(top_k) for top_k in top_ks])
                    # utils.print_collection(headers, nqs, res)

                    # remove container, and run next definition
                    logger.info("remove container, and run next definition")
                    utils.remove_container(container)

        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    collection_name = param["dataset"]
                    index_type = param["index_type"]
                    volume_name = param["db_path_prefix"]
                    (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)

                    # set default test time
                    if "during_time" not in param:
                        during_time = 100 # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])

                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")

                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(collection_name)
                    # Check has collection or not
                    if not milvus.exists_collection():
                        logger.warning("Table %s not existed, continue exec next params ..." % collection_name)
                        continue

                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
                    i = 0
                    # Loop queries (and periodic inserts) until the time budget runs out.
                    while time.time() < start_time + during_time:
                        i = i + 1
                        processes = []
                        # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(collection_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 100)])
                        #     nprobe = random.choice([x for x in range(1, 1000)])
                        #     # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                        #     p = Process(target=self.do_query, args=(milvus_instance, collection_name, [top_k], [nq], [nprobe], run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        milvus_instance = MilvusClient(collection_name)
                        top_ks = random.sample([x for x in range(1, 100)], 3)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        res = self.do_query(milvus, collection_name, top_ks, nqs, nprobe, run_count)
                        # Every 10th iteration, also insert and rebuild the index.
                        # NOTE(review): nesting of the create_index/describe_index calls
                        # under this `if` was reconstructed — confirm against original.
                        if i % 10 == 0:
                            status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                            if not status.OK():
                                logger.error(status)
                            # status = milvus_instance.drop_index()
                            # if not status.OK():
                            #     logger.error(status)
                            # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            milvus_instance.create_index(index_type, 16384)
                            result = milvus.describe_index()
                            logger.info(result)
                            # milvus_instance.create_index("ivf_sq8", 16384)
                    utils.remove_container(container)

        else:
            logger.warning("Run type: %s not supported" % run_type)
|
||||
|
126
tests/benchmark/docker_utils.py
Normal file
126
tests/benchmark/docker_utils.py
Normal file
@ -0,0 +1,126 @@
|
||||
# def pull_image(image):
|
||||
# registry = image.split(":")[0]
|
||||
# image_tag = image.split(":")[1]
|
||||
# client = docker.APIClient(base_url='unix://var/run/docker.sock')
|
||||
# logger.info("Start pulling image: %s" % image)
|
||||
# return client.pull(registry, image_tag)
|
||||
|
||||
|
||||
# def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None):
|
||||
# import colors
|
||||
|
||||
# client = docker.from_env()
|
||||
# # if mem_limit is None:
|
||||
# # mem_limit = psutil.virtual_memory().available
|
||||
# # logger.info('Memory limit:', mem_limit)
|
||||
# # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1)
|
||||
# # logger.info('Running on CPUs:', cpu_limit)
|
||||
# for dir_item in ['logs', 'db']:
|
||||
# try:
|
||||
# os.mkdir(os.path.abspath(dir_item))
|
||||
# except Exception as e:
|
||||
# pass
|
||||
|
||||
# if test_type == "local":
|
||||
# volumes = {
|
||||
# os.path.abspath('conf'):
|
||||
# {'bind': '/opt/milvus/conf', 'mode': 'ro'},
|
||||
# os.path.abspath('logs'):
|
||||
# {'bind': '/opt/milvus/logs', 'mode': 'rw'},
|
||||
# os.path.abspath('db'):
|
||||
# {'bind': '/opt/milvus/db', 'mode': 'rw'},
|
||||
# }
|
||||
# elif test_type == "remote":
|
||||
# if volume_name is None:
|
||||
# raise Exception("No volume name")
|
||||
# remote_log_dir = volume_name+'/logs'
|
||||
# remote_db_dir = volume_name+'/db'
|
||||
|
||||
# for dir_item in [remote_log_dir, remote_db_dir]:
|
||||
# if not os.path.isdir(dir_item):
|
||||
# os.makedirs(dir_item, exist_ok=True)
|
||||
# volumes = {
|
||||
# os.path.abspath('conf'):
|
||||
# {'bind': '/opt/milvus/conf', 'mode': 'ro'},
|
||||
# remote_log_dir:
|
||||
# {'bind': '/opt/milvus/logs', 'mode': 'rw'},
|
||||
# remote_db_dir:
|
||||
# {'bind': '/opt/milvus/db', 'mode': 'rw'}
|
||||
# }
|
||||
# # add volumes
|
||||
# if db_slave and isinstance(db_slave, int):
|
||||
# for i in range(2, db_slave+1):
|
||||
# remote_db_dir = volume_name+'/data'+str(i)
|
||||
# if not os.path.isdir(remote_db_dir):
|
||||
# os.makedirs(remote_db_dir, exist_ok=True)
|
||||
# volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'}
|
||||
|
||||
# container = client.containers.run(
|
||||
# image,
|
||||
# volumes=volumes,
|
||||
# runtime="nvidia",
|
||||
# ports={'19530/tcp': 19530, '8080/tcp': 8080},
|
||||
# # environment=["OMP_NUM_THREADS=48"],
|
||||
# # cpuset_cpus=cpu_limit,
|
||||
# # mem_limit=mem_limit,
|
||||
# # environment=[""],
|
||||
# detach=True)
|
||||
|
||||
# def stream_logs():
|
||||
# for line in container.logs(stream=True):
|
||||
# logger.info(colors.color(line.decode().rstrip(), fg='blue'))
|
||||
|
||||
# if sys.version_info >= (3, 0):
|
||||
# t = threading.Thread(target=stream_logs, daemon=True)
|
||||
# else:
|
||||
# t = threading.Thread(target=stream_logs)
|
||||
# t.daemon = True
|
||||
# t.start()
|
||||
|
||||
# logger.info('Container: %s started' % container)
|
||||
# return container
|
||||
# # exit_code = container.wait(timeout=timeout)
|
||||
# # # Exit if exit code
|
||||
# # if exit_code == 0:
|
||||
# # return container
|
||||
# # elif exit_code is not None:
|
||||
# # print(colors.color(container.logs().decode(), fg='red'))
|
||||
|
||||
# def restart_server(container):
|
||||
# client = docker.APIClient(base_url='unix://var/run/docker.sock')
|
||||
|
||||
# client.restart(container.name)
|
||||
# logger.info('Container: %s restarted' % container.name)
|
||||
# return container
|
||||
|
||||
|
||||
# def remove_container(container):
|
||||
# container.remove(force=True)
|
||||
# logger.info('Container: %s removed' % container)
|
||||
|
||||
|
||||
# def remove_all_containers(image):
|
||||
# client = docker.from_env()
|
||||
# try:
|
||||
# for container in client.containers.list():
|
||||
# if image in container.image.tags:
|
||||
# container.stop(timeout=30)
|
||||
# container.remove(force=True)
|
||||
# except Exception as e:
|
||||
# logger.error("Containers removed failed")
|
||||
|
||||
|
||||
# def container_exists(image):
|
||||
# '''
|
||||
# Check if container existed with the given image name
|
||||
# @params: image name
|
||||
# @return: container if exists
|
||||
# '''
|
||||
# res = False
|
||||
# client = docker.from_env()
|
||||
# for container in client.containers.list():
|
||||
# if image in container.image.tags:
|
||||
# # True
|
||||
# res = container
|
||||
# return res
|
||||
|
3
tests/benchmark/executors/__init__.py
Normal file
3
tests/benchmark/executors/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
|
||||
class BaseExecutor(object):
    """Common base type for all benchmark executors."""
    pass
|
4
tests/benchmark/executors/shell.py
Normal file
4
tests/benchmark/executors/shell.py
Normal file
@ -0,0 +1,4 @@
|
||||
from . import BaseExecutor
|
||||
|
||||
class ShellExecutor(BaseExecutor):
    """Executor that runs commands via a shell; placeholder with no behavior yet."""
    pass
|
0
tests/benchmark/handlers/__init__.py
Normal file
0
tests/benchmark/handlers/__init__.py
Normal file
370
tests/benchmark/helm_utils.py
Normal file
370
tests/benchmark/helm_utils.py
Normal file
@ -0,0 +1,370 @@
|
||||
import os
|
||||
import pdb
|
||||
import time
|
||||
import logging
|
||||
import hashlib
|
||||
from yaml import full_load, dump
|
||||
import utils
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.utils")
|
||||
REGISTRY_URL = "registry.zilliz.com/milvus/engine"
|
||||
IDC_NAS_URL = "//172.16.70.249/test"
|
||||
NAS_URL = "//192.168.1.126/test"
|
||||
|
||||
|
||||
def get_host_cpus(hostname):
    """Return the allocatable CPU quantity that Kubernetes reports for `hostname`.

    Reads the node object via the CoreV1 API using the local kubeconfig.
    The value is whatever k8s stores under `status.allocatable["cpu"]`.
    """
    from kubernetes import client, config
    config.load_kube_config()
    # Silence the noisy per-request REST logging from the k8s client.
    client.rest.logger.setLevel(logging.WARNING)
    api = client.CoreV1Api()
    node_status = api.read_node(hostname).status
    return node_status.allocatable.get("cpu")
|
||||
|
||||
|
||||
# update values.yaml
|
||||
# update values.yaml
def update_values(file_path, deploy_mode, hostname, milvus_config, server_config=None):
    """Rewrite a helm chart's values.yaml in place for a benchmark run.

    file_path:     path to values.yaml (a .bak copy is kept/restored beside it).
    deploy_mode:   "local"/"cluster" — selects external-mysql endpoints.
    hostname:      if truthy, pin the pod to this node and set CPU resources.
    milvus_config: flat dict of milvus settings; keys are matched by substring.
    server_config: expected to carry "cpus" when hostname is set.
    NOTE(review): reconstructed from a whitespace-mangled diff; block nesting
    marked below as "reconstructed" should be confirmed against the original file.
    """
    if not os.path.isfile(file_path):
        raise Exception('File: %s not found' % file_path)
    # bak values.yaml: restore from backup if one exists, else create it,
    # so every run starts from the same pristine template.
    file_name = os.path.basename(file_path)
    bak_file_name = file_name + ".bak"
    file_parent_path = os.path.dirname(file_path)
    bak_file_path = file_parent_path + '/' + bak_file_name
    if os.path.exists(bak_file_path):
        os.system("cp %s %s" % (bak_file_path, file_path))
    else:
        os.system("cp %s %s" % (file_path, bak_file_path))
    with open(file_path) as f:
        values_dict = full_load(f)
        f.close()
    cluster = False
    if "cluster" in milvus_config and milvus_config["cluster"]:
        cluster = True
    # Map milvus_config keys (matched by substring) onto values.yaml fields;
    # cluster mode mirrors most settings into the readonly replica section.
    for k, v in milvus_config.items():
        if k.find("primary_path") != -1:
            suffix_path = milvus_config["suffix_path"] if "suffix_path" in milvus_config else None
            path_value = v
            if suffix_path:
                # Make the data path unique per run via a timestamp suffix.
                path_value = v + "_" + str(int(time.time()))
            values_dict["primaryPath"] = path_value
            values_dict['wal']['path'] = path_value + "/wal"
            values_dict['logs']['path'] = path_value + "/logs"
        # elif k.find("use_blas_threshold") != -1:
        #     values_dict['useBLASThreshold'] = int(v)
        elif k.find("gpu_search_threshold") != -1:
            values_dict['gpu']['gpuSearchThreshold'] = int(v)
            if cluster:
                values_dict['readonly']['gpu']['gpuSearchThreshold'] = int(v)
        elif k.find("cpu_cache_capacity") != -1:
            values_dict['cache']['cacheSize'] = v
            if cluster:
                values_dict['readonly']['cache']['cacheSize'] = v
        # elif k.find("cache_insert_data") != -1:
        #     values_dict['cache']['cacheInsertData'] = v
        elif k.find("insert_buffer_size") != -1:
            values_dict['cache']['insertBufferSize'] = v
            if cluster:
                values_dict['readonly']['cache']['insertBufferSize'] = v
        elif k.find("gpu_resource_config.enable") != -1:
            values_dict['gpu']['enabled'] = v
            if cluster:
                values_dict['readonly']['gpu']['enabled'] = v
        elif k.find("gpu_resource_config.cache_capacity") != -1:
            values_dict['gpu']['cacheSize'] = v
            if cluster:
                values_dict['readonly']['gpu']['cacheSize'] = v
        elif k.find("build_index_resources") != -1:
            values_dict['gpu']['buildIndexDevices'] = v
            if cluster:
                values_dict['readonly']['gpu']['buildIndexDevices'] = v
        elif k.find("search_resources") != -1:
            values_dict['gpu']['searchDevices'] = v
            if cluster:
                values_dict['readonly']['gpu']['searchDevices'] = v
        # wal
        elif k.find("auto_flush_interval") != -1:
            values_dict['storage']['autoFlushInterval'] = v
            if cluster:
                values_dict['readonly']['storage']['autoFlushInterval'] = v
        elif k.find("wal_enable") != -1:
            values_dict['wal']['enabled'] = v

    # if values_dict['nodeSelector']:
    #     logger.warning("nodeSelector has been set: %s" % str(values_dict['engine']['nodeSelector']))
    #     return
    values_dict["wal"]["recoveryErrorIgnore"] = True
    # enable monitor
    values_dict["metrics"]["enabled"] = True
    values_dict["metrics"]["address"] = "192.168.1.237"
    values_dict["metrics"]["port"] = 9091
    # only test avx2
    values_dict["extraConfiguration"].update({"engine": {"simd_type": "avx2"}})
    # stat_optimizer_enable
    values_dict["extraConfiguration"]["engine"].update({"stat_optimizer_enable": False})

    # enable read-write mode
    if cluster:
        values_dict["cluster"]["enabled"] = True
        # update readonly log path
        values_dict["readonly"]['logs']['path'] = values_dict['logs']['path'] + "/readonly"
        if "readonly" in milvus_config:
            if "replicas" in milvus_config["readonly"]:
                values_dict["readonly"]["replicas"] = milvus_config["readonly"]["replicas"]

    use_external_mysql = False
    if "external_mysql" in milvus_config and milvus_config["external_mysql"]:
        use_external_mysql = True
    # meta mysql
    if use_external_mysql:
        values_dict["mysql"]["enabled"] = False
        # values_dict["mysql"]["persistence"]["enabled"] = True
        # values_dict["mysql"]["persistence"]["existingClaim"] = hashlib.md5(path_value.encode(encoding='UTF-8')).hexdigest()
        values_dict['externalMysql']['enabled'] = True
        if deploy_mode == "local":
            values_dict['externalMysql']["ip"] = "192.168.1.238"
        else:
            values_dict['externalMysql']["ip"] = "milvus-mysql.test"
        values_dict['externalMysql']["port"] = 3306
        values_dict['externalMysql']["user"] = "root"
        values_dict['externalMysql']["password"] = "milvus"
        values_dict['externalMysql']["database"] = "db"
    else:
        values_dict["mysql"]["enabled"] = False
    # update values.yaml with the given host
    nas_url = NAS_URL
    if hostname:
        # NOTE(review): the extent of this `if hostname:` block (resources /
        # tolerations) was reconstructed — confirm against the original file.
        nas_url = IDC_NAS_URL
        values_dict['nodeSelector'] = {'kubernetes.io/hostname': hostname}
        cpus = server_config["cpus"]

        # set limit/request cpus in resources
        values_dict["image"]['resources'] = {
            "limits": {
                # "cpu": str(int(cpus)) + ".0"
                "cpu": str(int(cpus)) + ".0"
            },
            "requests": {
                # "cpu": str(int(cpus) // 2) + ".0"
                "cpu": "4.0"
            }
        }
        # update readonly resouces limits/requests
        values_dict["readonly"]['resources'] = {
            "limits": {
                # "cpu": str(int(cpus)) + ".0"
                "cpu": str(int(cpus)) + ".0"
            },
            "requests": {
                # "cpu": str(int(cpus) // 2) + ".0"
                "cpu": "4.0"
            }
        }
        values_dict['tolerations'] = [{
            "key": "worker",
            "operator": "Equal",
            "value": "performance",
            "effect": "NoSchedule"
        }]
    # add extra volumes: mount the NAS share (CIFS) at /test inside the pod.
    values_dict['extraVolumes'] = [{
        'name': 'test',
        'flexVolume': {
            'driver': "fstab/cifs",
            'fsType': "cifs",
            'secretRef': {
                'name': "cifs-test-secret"
            },
            'options': {
                'networkPath': nas_url,
                'mountOptions': "vers=1.0"
            }
        }
    }]
    values_dict['extraVolumeMounts'] = [{
        'name': 'test',
        'mountPath': '/test'
    }]

    # add extra volumes for mysql
    # values_dict['mysql']['persistence']['enabled'] = True
    # values_dict['mysql']['configurationFilesPath'] = "/etc/mysql/mysql.conf.d/"
    # values_dict['mysql']['imageTag'] = '5.6'
    # values_dict['mysql']['securityContext'] = {
    #     'enabled': True}
    # mysql_db_path = "/test"
    if deploy_mode == "cluster" and use_external_mysql:
        # mount_path = values_dict["primaryPath"]+'/data'
        # long_str = '- name: test-mysql\n  flexVolume:\n    driver: fstab/cifs\n    fsType: cifs\n    secretRef:\n      name: cifs-test-secret\n    options:\n      networkPath: //192.168.1.126/test\n      mountOptions: vers=1.0'
        # values_dict['mysql']['extraVolumes'] = literal_str(long_str)
        # long_str_2 = "- name: test-mysql\n  mountPath: %s" % mysql_db_path
        # values_dict['mysql']['extraVolumeMounts'] = literal_str(long_str_2)
        # mysql_cnf_str = '[mysqld]\npid-file=%s/mysql.pid\ndatadir=%s' % (mount_path, mount_path)
        # values_dict['mysql']['configurationFiles'] = {}
        # values_dict['mysql']['configurationFiles']['mysqld.cnf'] = literal_str(mysql_cnf_str)

        # Cluster mode overrides the external mysql endpoint set earlier.
        values_dict['mysql']['enabled'] = False
        values_dict['externalMysql']['enabled'] = True
        values_dict['externalMysql']["ip"] = "192.168.1.197"
        values_dict['externalMysql']["port"] = 3306
        values_dict['externalMysql']["user"] = "root"
        values_dict['externalMysql']["password"] = "Fantast1c"
        values_dict['externalMysql']["database"] = "db"

    # logger.debug(values_dict)
    # print(dump(values_dict))
    with open(file_path, 'w') as f:
        dump(values_dict, f, default_flow_style=False)
        f.close()
    # DEBUG: re-read the rewritten file (result is discarded).
    with open(file_path) as f:
        for line in f.readlines():
            line = line.strip("\n")
|
||||
|
||||
|
||||
# deploy server
|
||||
# deploy server
def helm_install_server(helm_path, deploy_mode, image_tag, image_type, name, namespace):
    """Deploy a Milvus server release via `helm install` and return its
    in-cluster service hostname, or None on failure.

    helm_path:   directory containing the chart (commands run via `cd` there).
    deploy_mode: "single" or "cluster" — selects the install command.
    image_tag:   image tag pushed to REGISTRY_URL.
    image_type:  currently unused by this function.
    name:        helm release name; namespace: k8s namespace.
    NOTE(review): if deploy_mode is neither "single" nor "cluster",
    install_cmd is never assigned and the later debug log raises NameError.
    """
    timeout = 300  # currently unused
    logger.debug("Server deploy mode: %s" % deploy_mode)
    host = "%s.%s.svc.cluster.local" % (name, namespace)
    if deploy_mode == "single":
        install_cmd = "helm install \
                --set image.repository=%s \
                --set image.tag=%s \
                --set image.pullPolicy=Always \
                --set service.type=ClusterIP \
                -f ci/filebeat/values.yaml \
                --namespace %s \
                %s ." % (REGISTRY_URL, image_tag, namespace, name)
    elif deploy_mode == "cluster":
        install_cmd = "helm install \
                --set cluster.enabled=true \
                --set persistence.enabled=true \
                --set mishards.image.tag=test \
                --set mishards.image.pullPolicy=Always \
                --set image.repository=%s \
                --set image.tag=%s \
                --set image.pullPolicy=Always \
                --set service.type=ClusterIP \
                -f ci/filebeat/values.yaml \
                --namespace %s \
                %s ." % (REGISTRY_URL, image_tag, namespace, name)
    logger.debug(install_cmd)
    logger.debug(host)
    # os.system returns non-zero on failure.
    if os.system("cd %s && %s" % (helm_path, install_cmd)):
        logger.error("Helm install failed: %s" % name)
        return None
    # Give the release time to come up before callers connect.
    time.sleep(30)
    # config.load_kube_config()
    # v1 = client.CoreV1Api()
    # pod_name = None
    # pod_id = None
    # pods = v1.list_namespaced_pod(namespace)
    # for i in pods.items:
    #     if i.metadata.name.find(name) != -1:
    #         pod_name = i.metadata.name
    #         pod_ip = i.status.pod_ip
    # logger.debug(pod_name)
    # logger.debug(pod_ip)
    # return pod_name, pod_ip
    return host
|
||||
|
||||
|
||||
# delete server
|
||||
# delete server
@utils.retry(3)
def helm_del_server(name, namespace):
    """Uninstall the helm release `name`; returns True on success.

    NOTE(review): `namespace` is not used — the command hard-codes
    `-n milvus`; confirm whether that is intentional.
    """
    # logger.debug("Sleep 600s before uninstall server")
    # time.sleep(600)
    del_cmd = "helm uninstall -n milvus %s" % name
    logger.info(del_cmd)
    succeeded = os.system(del_cmd) == 0
    if not succeeded:
        logger.error("Helm delete name:%s failed" % name)
    return succeeded
|
||||
|
||||
|
||||
def restart_server(helm_release_name, namespace):
|
||||
res = True
|
||||
timeout = 120000
|
||||
# service_name = "%s.%s.svc.cluster.local" % (helm_release_name, namespace)
|
||||
config.load_kube_config()
|
||||
v1 = client.CoreV1Api()
|
||||
pod_name = None
|
||||
# config_map_names = v1.list_namespaced_config_map(namespace, pretty='true')
|
||||
# body = {"replicas": 0}
|
||||
pods = v1.list_namespaced_pod(namespace)
|
||||
for i in pods.items:
|
||||
if i.metadata.name.find(helm_release_name) != -1 and i.metadata.name.find("mysql") == -1:
|
||||
pod_name = i.metadata.name
|
||||
break
|
||||
# v1.patch_namespaced_config_map(config_map_name, namespace, body, pretty='true')
|
||||
# status_res = v1.read_namespaced_service_status(helm_release_name, namespace, pretty='true')
|
||||
logger.debug("Pod name: %s" % pod_name)
|
||||
if pod_name is not None:
|
||||
try:
|
||||
v1.delete_namespaced_pod(pod_name, namespace)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
logger.error("Exception when calling CoreV1Api->delete_namespaced_pod")
|
||||
res = False
|
||||
return res
|
||||
logger.error("Sleep 10s after pod deleted")
|
||||
time.sleep(10)
|
||||
# check if restart successfully
|
||||
pods = v1.list_namespaced_pod(namespace)
|
||||
for i in pods.items:
|
||||
pod_name_tmp = i.metadata.name
|
||||
logger.error(pod_name_tmp)
|
||||
if pod_name_tmp == pod_name:
|
||||
continue
|
||||
elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
|
||||
continue
|
||||
else:
|
||||
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
|
||||
logger.error(status_res.status.phase)
|
||||
start_time = time.time()
|
||||
ready_break = False
|
||||
while time.time() - start_time <= timeout:
|
||||
logger.error(time.time())
|
||||
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
|
||||
if status_res.status.phase == "Running":
|
||||
logger.error("Already running")
|
||||
ready_break = True
|
||||
break
|
||||
else:
|
||||
time.sleep(5)
|
||||
if time.time() - start_time > timeout:
|
||||
logger.error("Restart pod: %s timeout" % pod_name_tmp)
|
||||
res = False
|
||||
return res
|
||||
if ready_break:
|
||||
break
|
||||
else:
|
||||
raise Exception("Pod: %s not found" % pod_name)
|
||||
follow = True
|
||||
pretty = True
|
||||
previous = True # bool | Return previous terminated container logs. Defaults to false. (optional)
|
||||
since_seconds = 56 # int | A relative time in seconds before the current time from which to show logs. If this value precedes the time a pod was started, only logs since the pod start will be returned. If this value is in the future, no logs will be returned. Only one of sinceSeconds or sinceTime may be specified. (optional)
|
||||
timestamps = True # bool | If true, add an RFC3339 or RFC3339Nano timestamp at the beginning of every line of log output. Defaults to false. (optional)
|
||||
container = "milvus"
|
||||
# start_time = time.time()
|
||||
# while time.time() - start_time <= timeout:
|
||||
# try:
|
||||
# api_response = v1.read_namespaced_pod_log(pod_name_tmp, namespace, container=container, follow=follow,
|
||||
# pretty=pretty, previous=previous, since_seconds=since_seconds,
|
||||
# timestamps=timestamps)
|
||||
# logging.error(api_response)
|
||||
# return res
|
||||
# except Exception as e:
|
||||
# logging.error("Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
|
||||
# # waiting for server start
|
||||
# time.sleep(2)
|
||||
# # res = False
|
||||
# # return res
|
||||
# if time.time() - start_time > timeout:
|
||||
# logging.error("Restart pod: %s timeout" % pod_name_tmp)
|
||||
# res = False
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(type(get_host_cpus("idc-sh002")))
|
927
tests/benchmark/k8s_runner.py
Normal file
927
tests/benchmark/k8s_runner.py
Normal file
@ -0,0 +1,927 @@
|
||||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import re
|
||||
import random
|
||||
import traceback
|
||||
import json
|
||||
import csv
|
||||
import threading
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from milvus import DataType
|
||||
from yaml import full_load, dump
|
||||
import concurrent.futures
|
||||
|
||||
import locust_user
|
||||
from client import MilvusClient
|
||||
import parser
|
||||
from runner import Runner
|
||||
from milvus_metrics.api import report
|
||||
from milvus_metrics.models import Env, Hardware, Server, Metric
|
||||
import helm_utils
|
||||
import utils
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.k8s_runner")
|
||||
namespace = "milvus"
|
||||
default_port = 19530
|
||||
DELETE_INTERVAL_TIME = 5
|
||||
# INSERT_INTERVAL = 100000
|
||||
INSERT_INTERVAL = 50000
|
||||
BIG_FLUSH_INTERVAL = 3600
|
||||
DEFAULT_FLUSH_INTERVAL = 1
|
||||
timestamp = int(time.time())
|
||||
default_path = "/var/lib/milvus"
|
||||
|
||||
|
||||
class K8sRunner(Runner):
|
||||
"""run docker mode"""
|
||||
|
||||
def __init__(self):
|
||||
super(K8sRunner, self).__init__()
|
||||
self.service_name = utils.get_unique_name()
|
||||
self.host = None
|
||||
self.port = default_port
|
||||
self.hostname = None
|
||||
self.env_value = None
|
||||
self.hardware = None
|
||||
self.deploy_mode = None
|
||||
|
||||
def init_env(self, milvus_config, server_config, server_host, deploy_mode, image_type, image_tag):
|
||||
logger.debug("Tests run on server host:")
|
||||
logger.debug(server_host)
|
||||
self.hostname = server_host
|
||||
self.deploy_mode = deploy_mode
|
||||
if self.hostname:
|
||||
try:
|
||||
cpus = helm_utils.get_host_cpus(self.hostname)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
cpus = 64
|
||||
logger.debug(type(cpus))
|
||||
if server_config:
|
||||
if "cpus" in server_config.keys():
|
||||
cpus = min(server_config["cpus"], int(cpus))
|
||||
else:
|
||||
server_config.update({"cpus": cpus})
|
||||
else:
|
||||
server_config = {"cpus": cpus}
|
||||
self.hardware = Hardware(name=self.hostname, cpus=cpus)
|
||||
# update values
|
||||
helm_path = os.path.join(os.getcwd(), "../milvus-helm/charts/milvus")
|
||||
values_file_path = helm_path + "/values.yaml"
|
||||
if not os.path.exists(values_file_path):
|
||||
raise Exception("File %s not existed" % values_file_path)
|
||||
if milvus_config:
|
||||
helm_utils.update_values(values_file_path, deploy_mode, server_host, milvus_config, server_config)
|
||||
try:
|
||||
logger.debug("Start install server")
|
||||
self.host = helm_utils.helm_install_server(helm_path, deploy_mode, image_tag, image_type, self.service_name,
|
||||
namespace)
|
||||
except Exception as e:
|
||||
logger.error("Helm install server failed: %s" % (str(e)))
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(self.hostname)
|
||||
self.clean_up()
|
||||
return False
|
||||
logger.debug(server_config)
|
||||
# for debugging
|
||||
if not self.host:
|
||||
logger.error("Helm install server failed")
|
||||
self.clean_up()
|
||||
return False
|
||||
return True
|
||||
|
||||
def clean_up(self):
|
||||
logger.debug("Start clean up: %s" % self.service_name)
|
||||
helm_utils.helm_del_server(self.service_name, namespace)
|
||||
|
||||
def report_wrapper(self, milvus_instance, env_value, hostname, collection_info, index_info, search_params,
|
||||
run_params=None, server_config=None):
|
||||
metric = Metric()
|
||||
metric.set_run_id(timestamp)
|
||||
metric.env = Env(env_value)
|
||||
metric.env.OMP_NUM_THREADS = 0
|
||||
metric.hardware = self.hardware
|
||||
# TODO: removed
|
||||
# server_version = milvus_instance.get_server_version()
|
||||
# server_mode = milvus_instance.get_server_mode()
|
||||
# commit = milvus_instance.get_server_commit()
|
||||
server_version = "0.12.0"
|
||||
server_mode = self.deploy_mode
|
||||
metric.server = Server(version=server_version, mode=server_mode, build_commit=None)
|
||||
metric.collection = collection_info
|
||||
metric.index = index_info
|
||||
metric.search = search_params
|
||||
metric.run_params = run_params
|
||||
return metric
|
||||
|
||||
def run(self, run_type, collection):
|
||||
logger.debug(run_type)
|
||||
logger.debug(collection)
|
||||
collection_name = collection["collection_name"] if "collection_name" in collection else None
|
||||
milvus_instance = MilvusClient(collection_name=collection_name, host=self.host)
|
||||
|
||||
# TODO: removed
|
||||
# self.env_value = milvus_instance.get_server_config()
|
||||
# ugly implemention
|
||||
# self.env_value = utils.convert_nested(self.env_value)
|
||||
# self.env_value.pop("logs")
|
||||
# self.env_value.pop("network")
|
||||
self.env_value = collection
|
||||
|
||||
if run_type == "insert_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
build_index = collection["build_index"]
|
||||
if milvus_instance.exists_collection():
|
||||
milvus_instance.drop()
|
||||
time.sleep(10)
|
||||
index_info = {}
|
||||
search_params = {}
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
other_fields = collection["other_fields"] if "other_fields" in collection else None
|
||||
milvus_instance.create_collection(dimension, data_type=vector_type,
|
||||
other_fields=other_fields)
|
||||
if build_index is True:
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
index_info = {
|
||||
"index_type": index_type,
|
||||
"index_param": index_param
|
||||
}
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
res = self.do_insert(milvus_instance, collection_name, data_type, dimension, collection_size, ni_per)
|
||||
flush_time = 0.0
|
||||
if "flush" in collection and collection["flush"] == "no":
|
||||
logger.debug("No manual flush")
|
||||
else:
|
||||
start_time = time.time()
|
||||
milvus_instance.flush()
|
||||
flush_time = time.time() - start_time
|
||||
logger.debug(milvus_instance.count())
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name,
|
||||
"other_fields": other_fields,
|
||||
"ni_per": ni_per
|
||||
}
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
search_params)
|
||||
total_time = res["total_time"]
|
||||
build_time = 0
|
||||
if build_index is True:
|
||||
logger.debug("Start build index for last file")
|
||||
start_time = time.time()
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
build_time = time.time() - start_time
|
||||
total_time = total_time + build_time
|
||||
metric.metrics = {
|
||||
"type": run_type,
|
||||
"value": {
|
||||
"total_time": total_time,
|
||||
"qps": res["qps"],
|
||||
"ni_time": res["ni_time"],
|
||||
"flush_time": flush_time,
|
||||
"build_time": build_time
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "build_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
index_info = {
|
||||
"index_type": index_type,
|
||||
"index_param": index_param
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
search_params = {}
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
start_time = time.time()
|
||||
# drop index
|
||||
logger.debug("Drop index")
|
||||
milvus_instance.drop_index(index_field_name)
|
||||
# start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
# TODO: need to check
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
logger.debug(milvus_instance.count())
|
||||
end_time = time.time()
|
||||
# end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
search_params)
|
||||
metric.metrics = {
|
||||
"type": "build_performance",
|
||||
"value": {
|
||||
"build_time": round(end_time - start_time, 1),
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "delete_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
auto_flush = collection["auto_flush"] if "auto_flush" in collection else True
|
||||
search_params = {}
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error(milvus_instance.show_collections())
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
length = milvus_instance.count()
|
||||
logger.info(length)
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
ids = [i for i in range(length)]
|
||||
loops = int(length / ni_per)
|
||||
milvus_instance.load_collection()
|
||||
# TODO: remove
|
||||
# start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
start_time = time.time()
|
||||
# if auto_flush is False:
|
||||
# milvus_instance.set_config("storage", "auto_flush_interval", BIG_FLUSH_INTERVAL)
|
||||
for i in range(loops):
|
||||
delete_ids = ids[i * ni_per: i * ni_per + ni_per]
|
||||
logger.debug("Delete %d - %d" % (delete_ids[0], delete_ids[-1]))
|
||||
milvus_instance.delete(delete_ids)
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
start_flush_time = time.time()
|
||||
milvus_instance.flush()
|
||||
end_flush_time = time.time()
|
||||
end_time = time.time()
|
||||
# end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
# milvus_instance.set_config("storage", "auto_flush_interval", DEFAULT_FLUSH_INTERVAL)
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
search_params)
|
||||
delete_time = round(end_time - start_time, 1)
|
||||
metric.metrics = {
|
||||
"type": "delete_performance",
|
||||
"value": {
|
||||
"delete_time": delete_time,
|
||||
"qps": round(collection_size / delete_time, 1)
|
||||
}
|
||||
}
|
||||
if auto_flush is False:
|
||||
flush_time = round(end_flush_time - start_flush_time, 1)
|
||||
metric.metrics["value"].update({"flush_time": flush_time})
|
||||
report(metric)
|
||||
|
||||
elif run_type == "get_ids_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
ids_length_per_segment = collection["ids_length_per_segment"]
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
search_params = {}
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
for ids_num in ids_length_per_segment:
|
||||
segment_num, get_ids = milvus_instance.get_rand_ids_each_segment(ids_num)
|
||||
start_time = time.time()
|
||||
get_res = milvus_instance.get_entities(get_ids)
|
||||
total_time = time.time() - start_time
|
||||
avg_time = total_time / segment_num
|
||||
run_params = {"ids_num": ids_num}
|
||||
logger.info(
|
||||
"Segment num: %d, ids num per segment: %d, run_time: %f" % (segment_num, ids_num, total_time))
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info,
|
||||
index_info, search_params, run_params=run_params)
|
||||
metric.metrics = {
|
||||
"type": run_type,
|
||||
"value": {
|
||||
"total_time": round(total_time, 1),
|
||||
"avg_time": round(avg_time, 1)
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "search_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
run_count = collection["run_count"]
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
# filter_query = collection["filter"] if "filter" in collection else None
|
||||
filters = collection["filters"] if "filters" in collection else []
|
||||
filter_query = []
|
||||
search_params = collection["search_params"]
|
||||
fields = self.get_fields(milvus_instance, collection_name)
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
"fields": fields
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
milvus_instance.load_collection()
|
||||
logger.info("Start warm up query")
|
||||
res = self.do_query(milvus_instance, collection_name, vec_field_name, [1], [1], 2,
|
||||
search_param=search_params[0], filter_query=filter_query)
|
||||
logger.info("End warm up query")
|
||||
for search_param in search_params:
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
if not filters:
|
||||
filters.append(None)
|
||||
for filter in filters:
|
||||
filter_param = []
|
||||
if isinstance(filter, dict) and "range" in filter:
|
||||
filter_query.append(eval(filter["range"]))
|
||||
filter_param.append(filter["range"])
|
||||
if isinstance(filter, dict) and "term" in filter:
|
||||
filter_query.append(eval(filter["term"]))
|
||||
filter_param.append(filter["term"])
|
||||
logger.info("filter param: %s" % json.dumps(filter_param))
|
||||
res = self.do_query(milvus_instance, collection_name, vec_field_name, top_ks, nqs, run_count,
|
||||
search_param, filter_query=filter_query)
|
||||
headers = ["Nq/Top-k"]
|
||||
headers.extend([str(top_k) for top_k in top_ks])
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
utils.print_table(headers, nqs, res)
|
||||
for index_nq, nq in enumerate(nqs):
|
||||
for index_top_k, top_k in enumerate(top_ks):
|
||||
search_param_group = {
|
||||
"nq": nq,
|
||||
"topk": top_k,
|
||||
"search_param": search_param,
|
||||
"filter": filter_param
|
||||
}
|
||||
search_time = res[index_nq][index_top_k]
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname,
|
||||
collection_info, index_info, search_param_group)
|
||||
metric.metrics = {
|
||||
"type": "search_performance",
|
||||
"value": {
|
||||
"search_time": search_time
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "locust_insert_stress":
|
||||
pass
|
||||
|
||||
elif run_type in ["locust_search_performance", "locust_insert_performance", "locust_mix_performance"]:
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
build_index = collection["build_index"]
|
||||
if milvus_instance.exists_collection():
|
||||
milvus_instance.drop()
|
||||
time.sleep(10)
|
||||
index_info = {}
|
||||
search_params = {}
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
milvus_instance.create_collection(dimension, data_type=vector_type, other_fields=None)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
if build_index is True:
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
index_info = {
|
||||
"index_type": index_type,
|
||||
"index_param": index_param
|
||||
}
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
if run_type in ["locust_search_performance", "locust_mix_performance"]:
|
||||
res = self.do_insert(milvus_instance, collection_name, data_type, dimension, collection_size, ni_per)
|
||||
if "flush" in collection and collection["flush"] == "no":
|
||||
logger.debug("No manual flush")
|
||||
else:
|
||||
milvus_instance.flush()
|
||||
if build_index is True:
|
||||
logger.debug("Start build index for last file")
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, _async=True,
|
||||
index_param=index_param)
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
milvus_instance.load_collection()
|
||||
logger.info("Start warm up query")
|
||||
for i in range(2):
|
||||
res = self.do_query(milvus_instance, collection_name, vec_field_name, [1], [1], 2,
|
||||
search_param={"nprobe": 16})
|
||||
logger.info("End warm up query")
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
### spawn locust requests
|
||||
task = collection["task"]
|
||||
connection_type = "single"
|
||||
connection_num = task["connection_num"]
|
||||
if connection_num > 1:
|
||||
connection_type = "multi"
|
||||
clients_num = task["clients_num"]
|
||||
hatch_rate = task["hatch_rate"]
|
||||
during_time = utils.timestr_to_int(task["during_time"])
|
||||
task_types = task["types"]
|
||||
run_params = {"tasks": {}, "clients_num": clients_num, "spawn_rate": hatch_rate, "during_time": during_time}
|
||||
for task_type in task_types:
|
||||
run_params["tasks"].update({task_type["type"]: task_type["weight"] if "weight" in task_type else 1})
|
||||
|
||||
# . collect stats
|
||||
locust_stats = locust_user.locust_executor(self.host, self.port, collection_name,
|
||||
connection_type=connection_type, run_params=run_params)
|
||||
logger.info(locust_stats)
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
search_params)
|
||||
metric.metrics = {
|
||||
"type": run_type,
|
||||
"value": locust_stats}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "search_ids_stability":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
search_params = collection["search_params"]
|
||||
during_time = collection["during_time"]
|
||||
ids_length = collection["ids_length"]
|
||||
ids = collection["ids"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
g_top_k = int(collection["top_ks"].split("-")[1])
|
||||
l_top_k = int(collection["top_ks"].split("-")[0])
|
||||
g_id = int(ids.split("-")[1])
|
||||
l_id = int(ids.split("-")[0])
|
||||
g_id_length = int(ids_length.split("-")[1])
|
||||
l_id_length = int(ids_length.split("-")[0])
|
||||
|
||||
milvus_instance.load_collection()
|
||||
# start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
# logger.debug(start_mem_usage)
|
||||
start_time = time.time()
|
||||
while time.time() < start_time + during_time * 60:
|
||||
search_param = {}
|
||||
top_k = random.randint(l_top_k, g_top_k)
|
||||
ids_num = random.randint(l_id_length, g_id_length)
|
||||
ids_param = [random.randint(l_id_length, g_id_length) for _ in range(ids_num)]
|
||||
for k, v in search_params.items():
|
||||
search_param[k] = random.randint(int(v.split("-")[0]), int(v.split("-")[1]))
|
||||
logger.debug("Query top-k: %d, ids_num: %d, param: %s" % (top_k, ids_num, json.dumps(search_param)))
|
||||
result = milvus_instance.query_ids(top_k, ids_param, search_param=search_param)
|
||||
# end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
{})
|
||||
metric.metrics = {
|
||||
"type": "search_ids_stability",
|
||||
"value": {
|
||||
"during_time": during_time,
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
# for sift/deep datasets
|
||||
# TODO: enable
|
||||
elif run_type == "accuracy":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
search_params = collection["search_params"]
|
||||
# mapping to search param list
|
||||
search_params = self.generate_combinations(search_params)
|
||||
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
milvus_instance.load_collection()
|
||||
true_ids_all = self.get_groundtruth_ids(collection_size)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
for search_param in search_params:
|
||||
headers = ["Nq/Top-k"]
|
||||
res = []
|
||||
for nq in nqs:
|
||||
for top_k in top_ks:
|
||||
tmp_res = []
|
||||
search_param_group = {
|
||||
"nq": nq,
|
||||
"topk": top_k,
|
||||
"search_param": search_param,
|
||||
"metric_type": metric_type
|
||||
}
|
||||
logger.info("Query params: %s" % json.dumps(search_param_group))
|
||||
result_ids = self.do_query_ids(milvus_instance, collection_name, vec_field_name, top_k, nq,
|
||||
search_param=search_param)
|
||||
# mem_used = milvus_instance.get_mem_info()["memory_used"]
|
||||
acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
|
||||
logger.info("Query accuracy: %s" % acc_value)
|
||||
tmp_res.append(acc_value)
|
||||
# logger.info("Memory usage: %s" % mem_used)
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info,
|
||||
index_info, search_param_group)
|
||||
metric.metrics = {
|
||||
"type": "accuracy",
|
||||
"value": {
|
||||
"acc": acc_value
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
# logger.info("Memory usage: %s" % mem_used)
|
||||
res.append(tmp_res)
|
||||
headers.extend([str(top_k) for top_k in top_ks])
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
utils.print_table(headers, nqs, res)
|
||||
|
||||
elif run_type == "ann_accuracy":
|
||||
hdf5_source_file = collection["source_file"]
|
||||
collection_name = collection["collection_name"]
|
||||
index_types = collection["index_types"]
|
||||
index_params = collection["index_params"]
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
search_params = collection["search_params"]
|
||||
# mapping to search param list
|
||||
search_params = self.generate_combinations(search_params)
|
||||
# mapping to index param list
|
||||
index_params = self.generate_combinations(index_params)
|
||||
|
||||
data_type, dimension, metric_type = parser.parse_ann_collection_name(collection_name)
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
dataset = utils.get_dataset(hdf5_source_file)
|
||||
if milvus_instance.exists_collection(collection_name):
|
||||
logger.info("Re-create collection: %s" % collection_name)
|
||||
milvus_instance.drop()
|
||||
time.sleep(DELETE_INTERVAL_TIME)
|
||||
true_ids = np.array(dataset["neighbors"])
|
||||
vector_type = self.get_vector_type_from_metric(metric_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
|
||||
# re-create collection
|
||||
if milvus_instance.exists_collection(collection_name):
|
||||
milvus_instance.drop()
|
||||
time.sleep(DELETE_INTERVAL_TIME)
|
||||
milvus_instance.create_collection(dimension, data_type=vector_type)
|
||||
insert_vectors = self.normalize(metric_type, np.array(dataset["train"]))
|
||||
if len(insert_vectors) != dataset["train"].shape[0]:
|
||||
raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % (
|
||||
len(insert_vectors), dataset["train"].shape[0]))
|
||||
logger.debug("The row count of entities to be inserted: %d" % len(insert_vectors))
|
||||
# Insert batch once
|
||||
# milvus_instance.insert(insert_vectors)
|
||||
loops = len(insert_vectors) // INSERT_INTERVAL + 1
|
||||
for i in range(loops):
|
||||
start = i * INSERT_INTERVAL
|
||||
end = min((i + 1) * INSERT_INTERVAL, len(insert_vectors))
|
||||
if start < end:
|
||||
tmp_vectors = insert_vectors[start:end]
|
||||
ids = [i for i in range(start, end)]
|
||||
if not isinstance(tmp_vectors, list):
|
||||
entities = milvus_instance.generate_entities(tmp_vectors.tolist(), ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
else:
|
||||
entities = milvus_instance.generate_entities(tmp_vectors, ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
assert res_ids == ids
|
||||
milvus_instance.flush()
|
||||
res_count = milvus_instance.count()
|
||||
logger.info("Table: %s, row count: %d" % (collection_name, res_count))
|
||||
if res_count != len(insert_vectors):
|
||||
raise Exception("Table row count is not equal to insert vectors")
|
||||
for index_type in index_types:
|
||||
for index_param in index_params:
|
||||
logger.debug("Building index with param: %s" % json.dumps(index_param))
|
||||
if milvus_instance.get_config("cluster.enable") == "true":
|
||||
milvus_instance.create_index(vec_field_name, index_type, metric_type, _async=True,
|
||||
index_param=index_param)
|
||||
else:
|
||||
milvus_instance.create_index(vec_field_name, index_type, metric_type,
|
||||
index_param=index_param)
|
||||
logger.info(milvus_instance.describe_index())
|
||||
logger.info("Start load collection: %s" % collection_name)
|
||||
milvus_instance.load_collection()
|
||||
logger.info("End load collection: %s" % collection_name)
|
||||
index_info = {
|
||||
"index_type": index_type,
|
||||
"index_param": index_param
|
||||
}
|
||||
logger.debug(index_info)
|
||||
warm_up = True
|
||||
for search_param in search_params:
|
||||
for nq in nqs:
|
||||
query_vectors = self.normalize(metric_type, np.array(dataset["test"][:nq]))
|
||||
if not isinstance(query_vectors, list):
|
||||
query_vectors = query_vectors.tolist()
|
||||
for top_k in top_ks:
|
||||
search_param_group = {
|
||||
"nq": len(query_vectors),
|
||||
"topk": top_k,
|
||||
"search_param": search_param,
|
||||
"metric_type": metric_type
|
||||
}
|
||||
logger.debug(search_param_group)
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors,
|
||||
"metric_type": real_metric_type,
|
||||
"params": search_param}
|
||||
}}
|
||||
for i in range(2):
|
||||
result = milvus_instance.query(vector_query)
|
||||
warm_up = False
|
||||
logger.info("End warm up")
|
||||
result = milvus_instance.query(vector_query)
|
||||
result_ids = milvus_instance.get_ids(result)
|
||||
acc_value = self.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
|
||||
logger.info("Query ann_accuracy: %s" % acc_value)
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname,
|
||||
collection_info, index_info, search_param_group)
|
||||
metric.metrics = {
|
||||
"type": "ann_accuracy",
|
||||
"value": {
|
||||
"acc": acc_value
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "search_stability":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
search_params = collection["search_params"]
|
||||
during_time = collection["during_time"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
g_top_k = int(collection["top_ks"].split("-")[1])
|
||||
g_nq = int(collection["nqs"].split("-")[1])
|
||||
l_top_k = int(collection["top_ks"].split("-")[0])
|
||||
l_nq = int(collection["nqs"].split("-")[0])
|
||||
milvus_instance.load_collection()
|
||||
# start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
# logger.debug(start_mem_usage)
|
||||
start_row_count = milvus_instance.count()
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
logger.info(start_row_count)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
start_time = time.time()
|
||||
while time.time() < start_time + during_time * 60:
|
||||
search_param = {}
|
||||
top_k = random.randint(l_top_k, g_top_k)
|
||||
nq = random.randint(l_nq, g_nq)
|
||||
for k, v in search_params.items():
|
||||
search_param[k] = random.randint(int(v.split("-")[0]), int(v.split("-")[1]))
|
||||
query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
|
||||
logger.debug("Query nq: %d, top-k: %d, param: %s" % (nq, top_k, json.dumps(search_param)))
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors[:nq],
|
||||
"metric_type": real_metric_type,
|
||||
"params": search_param}
|
||||
}}
|
||||
milvus_instance.query(vector_query)
|
||||
# end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
{})
|
||||
metric.metrics = {
|
||||
"type": "search_stability",
|
||||
"value": {
|
||||
"during_time": during_time,
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "loop_stability":
|
||||
# init data
|
||||
milvus_instance.clean_db()
|
||||
pull_interval = collection["pull_interval"]
|
||||
collection_num = collection["collection_num"]
|
||||
concurrent = collection["concurrent"] if "concurrent" in collection else False
|
||||
concurrent_num = collection_num
|
||||
dimension = collection["dimension"] if "dimension" in collection else 128
|
||||
insert_xb = collection["insert_xb"] if "insert_xb" in collection else 100000
|
||||
index_types = collection["index_types"] if "index_types" in collection else ['ivf_sq8']
|
||||
index_param = {"nlist": 256}
|
||||
collection_names = []
|
||||
milvus_instances_map = {}
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(insert_xb)]
|
||||
ids = [i for i in range(insert_xb)]
|
||||
# initialize and prepare
|
||||
for i in range(collection_num):
|
||||
name = utils.get_unique_name(prefix="collection_%d_" % i)
|
||||
collection_names.append(name)
|
||||
metric_type = random.choice(["l2", "ip"])
|
||||
# default float_vector
|
||||
milvus_instance = MilvusClient(collection_name=name, host=self.host)
|
||||
milvus_instance.create_collection(dimension, other_fields=None)
|
||||
index_type = random.choice(index_types)
|
||||
field_name = utils.get_default_field_name()
|
||||
milvus_instance.create_index(field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.info(milvus_instance.describe_index())
|
||||
insert_vectors = utils.normalize(metric_type, insert_vectors)
|
||||
entities = milvus_instance.generate_entities(insert_vectors, ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
milvus_instance.flush()
|
||||
milvus_instances_map.update({name: milvus_instance})
|
||||
logger.info(milvus_instance.describe_index())
|
||||
|
||||
# loop time unit: min -> s
|
||||
pull_interval_seconds = pull_interval * 60
|
||||
tasks = ["insert_rand", "query_rand", "flush"]
|
||||
i = 1
|
||||
while True:
|
||||
logger.info("Loop time: %d" % i)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < pull_interval_seconds:
|
||||
if concurrent:
|
||||
threads = []
|
||||
for name in collection_names:
|
||||
task_name = random.choice(tasks)
|
||||
task_run = getattr(milvus_instances_map[name], task_name)
|
||||
t = threading.Thread(target=task_run, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_num) as executor:
|
||||
# future_results = {executor.submit(getattr(milvus_instances_map[mp[j][0]], mp[j][1])): j for j in range(concurrent_num)}
|
||||
# for future in concurrent.futures.as_completed(future_results):
|
||||
# future.result()
|
||||
else:
|
||||
tmp_collection_name = random.choice(collection_names)
|
||||
task_name = random.choice(tasks)
|
||||
logger.info(tmp_collection_name)
|
||||
logger.info(task_name)
|
||||
task_run = getattr(milvus_instances_map[tmp_collection_name], task_name)
|
||||
task_run()
|
||||
|
||||
logger.debug("Restart server")
|
||||
helm_utils.restart_server(self.service_name, namespace)
|
||||
# new connection
|
||||
# for name in collection_names:
|
||||
# milvus_instance = MilvusClient(collection_name=name, host=self.host)
|
||||
# milvus_instances_map.update({name: milvus_instance})
|
||||
time.sleep(30)
|
||||
i = i + 1
|
||||
|
||||
elif run_type == "stability":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
during_time = collection["during_time"]
|
||||
operations = collection["operations"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error(milvus_instance.show_collections())
|
||||
raise Exception("Table name: %s not existed" % collection_name)
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
# start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
start_row_count = milvus_instance.count()
|
||||
logger.info(start_row_count)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
query_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
|
||||
if "insert" in operations:
|
||||
insert_xb = operations["insert"]["xb"]
|
||||
if "delete" in operations:
|
||||
delete_xb = operations["delete"]["xb"]
|
||||
if "query" in operations:
|
||||
g_top_k = int(operations["query"]["top_ks"].split("-")[1])
|
||||
l_top_k = int(operations["query"]["top_ks"].split("-")[0])
|
||||
g_nq = int(operations["query"]["nqs"].split("-")[1])
|
||||
l_nq = int(operations["query"]["nqs"].split("-")[0])
|
||||
search_params = operations["query"]["search_params"]
|
||||
i = 0
|
||||
start_time = time.time()
|
||||
while time.time() < start_time + during_time * 60:
|
||||
i = i + 1
|
||||
q = self.gen_executors(operations)
|
||||
for name in q:
|
||||
try:
|
||||
if name == "insert":
|
||||
insert_ids = random.sample(list(range(collection_size)), insert_xb)
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(insert_xb)]
|
||||
entities = milvus_instance.generate_entities(insert_vectors, insert_ids)
|
||||
milvus_instance.insert(entities, ids=insert_ids)
|
||||
elif name == "delete":
|
||||
delete_ids = random.sample(list(range(collection_size)), delete_xb)
|
||||
milvus_instance.delete(delete_ids)
|
||||
elif name == "query":
|
||||
top_k = random.randint(l_top_k, g_top_k)
|
||||
nq = random.randint(l_nq, g_nq)
|
||||
search_param = {}
|
||||
for k, v in search_params.items():
|
||||
search_param[k] = random.randint(int(v.split("-")[0]), int(v.split("-")[1]))
|
||||
logger.debug("Query nq: %d, top-k: %d, param: %s" % (nq, top_k, json.dumps(search_param)))
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors[:nq],
|
||||
"metric_type": real_metric_type,
|
||||
"params": search_param}
|
||||
}}
|
||||
result = milvus_instance.query(vector_query)
|
||||
elif name in ["flush", "compact"]:
|
||||
func = getattr(milvus_instance, name)
|
||||
func()
|
||||
logger.debug(milvus_instance.count())
|
||||
except Exception as e:
|
||||
logger.error(name)
|
||||
logger.error(str(e))
|
||||
raise
|
||||
logger.debug("Loop time: %d" % i)
|
||||
# end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
end_row_count = milvus_instance.count()
|
||||
metric = self.report_wrapper(milvus_instance, self.env_value, self.hostname, collection_info, index_info,
|
||||
{})
|
||||
metric.metrics = {
|
||||
"type": "stability",
|
||||
"value": {
|
||||
"during_time": during_time,
|
||||
"row_count_increments": end_row_count - start_row_count
|
||||
}
|
||||
}
|
||||
report(metric)
|
||||
|
||||
elif run_type == "debug":
|
||||
time.sleep(7200)
|
||||
default_insert_vectors = [[random.random() for _ in range(128)] for _ in range(500000)]
|
||||
interval = 50000
|
||||
for loop in range(1, 7):
|
||||
insert_xb = loop * interval
|
||||
insert_vectors = default_insert_vectors[:insert_xb]
|
||||
insert_ids = [i for i in range(insert_xb)]
|
||||
entities = milvus_instance.generate_entities(insert_vectors, insert_ids)
|
||||
for j in range(5):
|
||||
milvus_instance.insert(entities, ids=insert_ids)
|
||||
time.sleep(10)
|
||||
|
||||
else:
|
||||
raise Exception("Run type not defined")
|
||||
logger.debug("All test finished")
|
732
tests/benchmark/local_runner.py
Normal file
732
tests/benchmark/local_runner.py
Normal file
@ -0,0 +1,732 @@
|
||||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import string
|
||||
import time
|
||||
import random
|
||||
import json
|
||||
import csv
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
import concurrent.futures
|
||||
from queue import Queue
|
||||
|
||||
import locust_user
|
||||
from milvus import DataType
|
||||
from client import MilvusClient
|
||||
from runner import Runner
|
||||
import utils
|
||||
import parser
|
||||
|
||||
|
||||
DELETE_INTERVAL_TIME = 5
|
||||
INSERT_INTERVAL = 50000
|
||||
logger = logging.getLogger("milvus_benchmark.local_runner")
|
||||
|
||||
|
||||
class LocalRunner(Runner):
|
||||
"""run local mode"""
|
||||
def __init__(self, host, port):
|
||||
super(LocalRunner, self).__init__()
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
def run(self, run_type, collection):
|
||||
logger.debug(run_type)
|
||||
logger.debug(collection)
|
||||
collection_name = collection["collection_name"] if "collection_name" in collection else None
|
||||
milvus_instance = MilvusClient(collection_name=collection_name, host=self.host, port=self.port)
|
||||
logger.info(milvus_instance.show_collections())
|
||||
# TODO:
|
||||
# self.env_value = milvus_instance.get_server_config()
|
||||
# ugly implemention
|
||||
# self.env_value = utils.convert_nested(self.env_value)
|
||||
# self.env_value.pop("logs")
|
||||
# self.env_value.pop("network")
|
||||
# logger.info(self.env_value)
|
||||
|
||||
if run_type in ["insert_performance", "insert_flush_performance"]:
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
build_index = collection["build_index"]
|
||||
if milvus_instance.exists_collection():
|
||||
milvus_instance.drop()
|
||||
time.sleep(10)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
other_fields = collection["other_fields"] if "other_fields" in collection else None
|
||||
milvus_instance.create_collection(dimension, data_type=vector_type, other_fields=other_fields)
|
||||
if build_index is True:
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
res = self.do_insert(milvus_instance, collection_name, data_type, dimension, collection_size, ni_per)
|
||||
milvus_instance.flush()
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
if build_index is True:
|
||||
logger.debug("Start build index for last file")
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
|
||||
elif run_type == "delete_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
auto_flush = collection["auto_flush"] if "auto_flush" in collection else True
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error(milvus_instance.show_collections())
|
||||
logger.error("Table: %s not found" % collection_name)
|
||||
return
|
||||
length = milvus_instance.count()
|
||||
ids = [i for i in range(length)]
|
||||
loops = int(length / ni_per)
|
||||
if auto_flush is False:
|
||||
milvus_instance.set_config("storage", "auto_flush_interval", BIG_FLUSH_INTERVAL)
|
||||
for i in range(loops):
|
||||
delete_ids = ids[i*ni_per: i*ni_per+ni_per]
|
||||
logger.debug("Delete %d - %d" % (delete_ids[0], delete_ids[-1]))
|
||||
milvus_instance.delete(delete_ids)
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
milvus_instance.flush()
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
|
||||
elif run_type == "build_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
# drop index
|
||||
logger.debug("Drop index")
|
||||
milvus_instance.drop_index(index_field_name)
|
||||
start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
start_time = time.time()
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
end_time = time.time()
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.debug("Diff memory: %s, current memory usage: %s, build time: %s" % ((end_mem_usage - start_mem_usage), end_mem_usage, round(end_time - start_time, 1)))
|
||||
|
||||
elif run_type == "search_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
run_count = collection["run_count"]
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
search_params = collection["search_params"]
|
||||
filter_query = []
|
||||
filters = collection["filters"] if "filters" in collection else []
|
||||
# pdb.set_trace()
|
||||
# ranges = collection["range"] if "range" in collection else None
|
||||
# terms = collection["term"] if "term" in collection else None
|
||||
# if ranges:
|
||||
# filter_query.append(eval(ranges))
|
||||
# if terms:
|
||||
# filter_query.append(eval(terms))
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
# for debugging
|
||||
# time.sleep(3600)
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
logger.info(milvus_instance.count())
|
||||
result = milvus_instance.describe_index()
|
||||
logger.info(result)
|
||||
milvus_instance.preload_collection()
|
||||
mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.info(mem_usage)
|
||||
for search_param in search_params:
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
filter_param = []
|
||||
if not filters:
|
||||
filters.append(None)
|
||||
for filter in filters:
|
||||
if isinstance(filter, dict) and "range" in filter:
|
||||
filter_query.append(eval(filter["range"]))
|
||||
filter_param.append(filter["range"])
|
||||
if isinstance(filter, dict) and "term" in filter:
|
||||
filter_query.append(eval(filter["term"]))
|
||||
filter_param.append(filter["term"])
|
||||
logger.info("filter param: %s" % json.dumps(filter_param))
|
||||
res = self.do_query(milvus_instance, collection_name, vec_field_name, top_ks, nqs, run_count, search_param, filter_query)
|
||||
headers = ["Nq/Top-k"]
|
||||
headers.extend([str(top_k) for top_k in top_ks])
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
utils.print_table(headers, nqs, res)
|
||||
mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.info(mem_usage)
|
||||
|
||||
elif run_type == "locust_search_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
build_index = collection["build_index"]
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
# if build_index is True:
|
||||
# index_type = collection["index_type"]
|
||||
# index_param = collection["index_param"]
|
||||
# # TODO: debug
|
||||
# if milvus_instance.exists_collection():
|
||||
# milvus_instance.drop()
|
||||
# time.sleep(10)
|
||||
# other_fields = collection["other_fields"] if "other_fields" in collection else None
|
||||
# milvus_instance.create_collection(dimension, data_type=vector_type, other_fields=other_fields)
|
||||
# milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
# res = self.do_insert(milvus_instance, collection_name, data_type, dimension, collection_size, ni_per)
|
||||
# milvus_instance.flush()
|
||||
# logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
# if build_index is True:
|
||||
# logger.debug("Start build index for last file")
|
||||
# milvus_instance.create_index(index_field_name, index_type, metric_type, index_param=index_param)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
### spawn locust requests
|
||||
task = collection["task"]
|
||||
connection_type = "single"
|
||||
connection_num = task["connection_num"]
|
||||
if connection_num > 1:
|
||||
connection_type = "multi"
|
||||
clients_num = task["clients_num"]
|
||||
hatch_rate = task["hatch_rate"]
|
||||
during_time = utils.timestr_to_int(task["during_time"])
|
||||
task_types = task["types"]
|
||||
# """
|
||||
# task:
|
||||
# connection_num: 1
|
||||
# clients_num: 100
|
||||
# hatch_rate: 2
|
||||
# during_time: 5m
|
||||
# types:
|
||||
# -
|
||||
# type: query
|
||||
# weight: 1
|
||||
# params:
|
||||
# top_k: 10
|
||||
# nq: 1
|
||||
# # filters:
|
||||
# # -
|
||||
# # range:
|
||||
# # int64:
|
||||
# # LT: 0
|
||||
# # GT: 1000000
|
||||
# search_param:
|
||||
# nprobe: 16
|
||||
# """
|
||||
run_params = {"tasks": {}, "clients_num": clients_num, "spawn_rate": hatch_rate, "during_time": during_time}
|
||||
for task_type in task_types:
|
||||
run_params["tasks"].update({task_type["type"]: task_type["weight"] if "weight" in task_type else 1})
|
||||
|
||||
#. collect stats
|
||||
locust_stats = locust_user.locust_executor(self.host, self.port, collection_name, connection_type=connection_type, run_params=run_params)
|
||||
logger.info(locust_stats)
|
||||
|
||||
elif run_type == "search_ids_stability":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
search_params = collection["search_params"]
|
||||
during_time = collection["during_time"]
|
||||
ids_length = collection["ids_length"]
|
||||
ids = collection["ids"]
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
g_top_k = int(collection["top_ks"].split("-")[1])
|
||||
l_top_k = int(collection["top_ks"].split("-")[0])
|
||||
g_id = int(ids.split("-")[1])
|
||||
l_id = int(ids.split("-")[0])
|
||||
g_id_length = int(ids_length.split("-")[1])
|
||||
l_id_length = int(ids_length.split("-")[0])
|
||||
|
||||
milvus_instance.preload_collection()
|
||||
start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.debug(start_mem_usage)
|
||||
start_time = time.time()
|
||||
while time.time() < start_time + during_time * 60:
|
||||
search_param = {}
|
||||
top_k = random.randint(l_top_k, g_top_k)
|
||||
ids_num = random.randint(l_id_length, g_id_length)
|
||||
l_ids = random.randint(l_id, g_id-ids_num)
|
||||
# ids_param = [random.randint(l_id_length, g_id_length) for _ in range(ids_num)]
|
||||
ids_param = [id for id in range(l_ids, l_ids+ids_num)]
|
||||
for k, v in search_params.items():
|
||||
search_param[k] = random.randint(int(v.split("-")[0]), int(v.split("-")[1]))
|
||||
logger.debug("Query top-k: %d, ids_num: %d, param: %s" % (top_k, ids_num, json.dumps(search_param)))
|
||||
result = milvus_instance.query_ids(top_k, ids_param, search_param=search_param)
|
||||
end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
metrics = {
|
||||
"during_time": during_time,
|
||||
"start_mem_usage": start_mem_usage,
|
||||
"end_mem_usage": end_mem_usage,
|
||||
"diff_mem": end_mem_usage - start_mem_usage,
|
||||
}
|
||||
logger.info(metrics)
|
||||
|
||||
elif run_type == "search_performance_concurrents":
|
||||
data_type, dimension, metric_type = parser.parse_ann_collection_name(collection_name)
|
||||
hdf5_source_file = collection["source_file"]
|
||||
use_single_connection = collection["use_single_connection"]
|
||||
concurrents = collection["concurrents"]
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
search_params = self.generate_combinations(collection["search_params"])
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
logger.info(milvus_instance.count())
|
||||
result = milvus_instance.describe_index()
|
||||
logger.info(result)
|
||||
milvus_instance.preload_collection()
|
||||
dataset = utils.get_dataset(hdf5_source_file)
|
||||
for concurrent_num in concurrents:
|
||||
top_k = top_ks[0]
|
||||
for nq in nqs:
|
||||
mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.info(mem_usage)
|
||||
query_vectors = self.normalize(metric_type, np.array(dataset["test"][:nq]))
|
||||
logger.debug(search_params)
|
||||
for search_param in search_params:
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
total_time = 0.0
|
||||
if use_single_connection is True:
|
||||
connections = [MilvusClient(collection_name=collection_name, host=self.host, port=self.port)]
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_num) as executor:
|
||||
future_results = {executor.submit(
|
||||
self.do_query_qps, connections[0], query_vectors, top_k, search_param=search_param) : index for index in range(concurrent_num)}
|
||||
else:
|
||||
connections = [MilvusClient(collection_name=collection_name, host=self.hos, port=self.port) for i in range(concurrent_num)]
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_num) as executor:
|
||||
future_results = {executor.submit(
|
||||
self.do_query_qps, connections[index], query_vectors, top_k, search_param=search_param) : index for index in range(concurrent_num)}
|
||||
for future in concurrent.futures.as_completed(future_results):
|
||||
total_time = total_time + future.result()
|
||||
qps_value = total_time / concurrent_num
|
||||
logger.debug("QPS value: %f, total_time: %f, request_nums: %f" % (qps_value, total_time, concurrent_num))
|
||||
mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
logger.info(mem_usage)
|
||||
|
||||
elif run_type == "ann_accuracy":
|
||||
hdf5_source_file = collection["source_file"]
|
||||
collection_name = collection["collection_name"]
|
||||
index_types = collection["index_types"]
|
||||
index_params = collection["index_params"]
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
search_params = collection["search_params"]
|
||||
# mapping to search param list
|
||||
search_params = self.generate_combinations(search_params)
|
||||
# mapping to index param list
|
||||
index_params = self.generate_combinations(index_params)
|
||||
data_type, dimension, metric_type = parser.parse_ann_collection_name(collection_name)
|
||||
dataset = utils.get_dataset(hdf5_source_file)
|
||||
true_ids = np.array(dataset["neighbors"])
|
||||
vector_type = self.get_vector_type_from_metric(metric_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
|
||||
# re-create collection
|
||||
if milvus_instance.exists_collection(collection_name):
|
||||
milvus_instance.drop()
|
||||
time.sleep(DELETE_INTERVAL_TIME)
|
||||
milvus_instance.create_collection(dimension, data_type=vector_type)
|
||||
insert_vectors = self.normalize(metric_type, np.array(dataset["train"]))
|
||||
if len(insert_vectors) != dataset["train"].shape[0]:
|
||||
raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % (len(insert_vectors), dataset["train"].shape[0]))
|
||||
logger.debug("The row count of entities to be inserted: %d" % len(insert_vectors))
|
||||
# insert batch once
|
||||
# milvus_instance.insert(insert_vectors)
|
||||
loops = len(insert_vectors) // INSERT_INTERVAL + 1
|
||||
for i in range(loops):
|
||||
start = i*INSERT_INTERVAL
|
||||
end = min((i+1)*INSERT_INTERVAL, len(insert_vectors))
|
||||
if start < end:
|
||||
tmp_vectors = insert_vectors[start:end]
|
||||
ids = [i for i in range(start, end)]
|
||||
if not isinstance(tmp_vectors, list):
|
||||
entities = milvus_instance.generate_entities(tmp_vectors.tolist(), ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
else:
|
||||
entities = milvus_instance.generate_entities(tmp_vectors, ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
assert res_ids == ids
|
||||
milvus_instance.flush()
|
||||
res_count = milvus_instance.count()
|
||||
logger.info("Table: %s, row count: %d" % (collection_name, res_count))
|
||||
if res_count != len(insert_vectors):
|
||||
raise Exception("Table row count is not equal to insert vectors")
|
||||
for index_type in index_types:
|
||||
for index_param in index_params:
|
||||
logger.debug("Building index with param: %s, metric_type: %s" % (json.dumps(index_param), metric_type))
|
||||
milvus_instance.create_index(vec_field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.info("Start preload collection: %s" % collection_name)
|
||||
milvus_instance.preload_collection()
|
||||
for search_param in search_params:
|
||||
for nq in nqs:
|
||||
query_vectors = self.normalize(metric_type, np.array(dataset["test"][:nq]))
|
||||
if not isinstance(query_vectors, list):
|
||||
query_vectors = query_vectors.tolist()
|
||||
for top_k in top_ks:
|
||||
logger.debug("Search nq: %d, top-k: %d, search_param: %s, metric_type: %s" % (nq, top_k, json.dumps(search_param), metric_type))
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors,
|
||||
"metric_type": real_metric_type,
|
||||
"params": search_param}
|
||||
}}
|
||||
result = milvus_instance.query(vector_query)
|
||||
result_ids = milvus_instance.get_ids(result)
|
||||
# pdb.set_trace()
|
||||
acc_value = self.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
|
||||
logger.info("Query ann_accuracy: %s" % acc_value)
|
||||
|
||||
elif run_type == "accuracy":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
search_params = collection["search_params"]
|
||||
# mapping to search param list
|
||||
search_params = self.generate_combinations(search_params)
|
||||
|
||||
top_ks = collection["top_ks"]
|
||||
nqs = collection["nqs"]
|
||||
collection_info = {
|
||||
"dimension": dimension,
|
||||
"metric_type": metric_type,
|
||||
"dataset_name": collection_name
|
||||
}
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error("Table name: %s not existed" % collection_name)
|
||||
return
|
||||
logger.info(milvus_instance.count())
|
||||
index_info = milvus_instance.describe_index()
|
||||
logger.info(index_info)
|
||||
milvus_instance.preload_collection()
|
||||
true_ids_all = self.get_groundtruth_ids(collection_size)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
for search_param in search_params:
|
||||
headers = ["Nq/Top-k"]
|
||||
res = []
|
||||
for nq in nqs:
|
||||
tmp_res = []
|
||||
for top_k in top_ks:
|
||||
search_param_group = {
|
||||
"nq": nq,
|
||||
"topk": top_k,
|
||||
"search_param": search_param,
|
||||
"metric_type": metric_type
|
||||
}
|
||||
logger.info("Query params: %s" % json.dumps(search_param_group))
|
||||
result_ids = self.do_query_ids(milvus_instance, collection_name, vec_field_name, top_k, nq, search_param=search_param)
|
||||
mem_used = milvus_instance.get_mem_info()["memory_used"]
|
||||
acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
|
||||
logger.info("Query accuracy: %s" % acc_value)
|
||||
tmp_res.append(acc_value)
|
||||
logger.info("Memory usage: %s" % mem_used)
|
||||
res.append(tmp_res)
|
||||
headers.extend([str(top_k) for top_k in top_ks])
|
||||
logger.info("Search param: %s" % json.dumps(search_param))
|
||||
utils.print_table(headers, nqs, res)
|
||||
|
||||
elif run_type == "stability":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
|
||||
during_time = collection["during_time"]
|
||||
operations = collection["operations"]
|
||||
if not milvus_instance.exists_collection():
|
||||
logger.error(milvus_instance.show_collections())
|
||||
raise Exception("Table name: %s not existed" % collection_name)
|
||||
milvus_instance.preload_collection()
|
||||
start_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
start_row_count = milvus_instance.count()
|
||||
logger.info(start_row_count)
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
vec_field_name = utils.get_default_field_name(vector_type)
|
||||
real_metric_type = utils.metric_type_trans(metric_type)
|
||||
query_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
|
||||
if "insert" in operations:
|
||||
insert_xb = operations["insert"]["xb"]
|
||||
if "delete" in operations:
|
||||
delete_xb = operations["delete"]["xb"]
|
||||
if "query" in operations:
|
||||
g_top_k = int(operations["query"]["top_ks"].split("-")[1])
|
||||
l_top_k = int(operations["query"]["top_ks"].split("-")[0])
|
||||
g_nq = int(operations["query"]["nqs"].split("-")[1])
|
||||
l_nq = int(operations["query"]["nqs"].split("-")[0])
|
||||
search_params = operations["query"]["search_params"]
|
||||
i = 0
|
||||
start_time = time.time()
|
||||
while time.time() < start_time + during_time * 60:
|
||||
i = i + 1
|
||||
q = self.gen_executors(operations)
|
||||
for name in q:
|
||||
try:
|
||||
if name == "insert":
|
||||
insert_ids = random.sample(list(range(collection_size)), insert_xb)
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(insert_xb)]
|
||||
entities = milvus_instance.generate_entities(insert_vectors, insert_ids)
|
||||
milvus_instance.insert(entities, ids=insert_ids)
|
||||
elif name == "delete":
|
||||
delete_ids = random.sample(list(range(collection_size)), delete_xb)
|
||||
milvus_instance.delete(delete_ids)
|
||||
elif name == "query":
|
||||
top_k = random.randint(l_top_k, g_top_k)
|
||||
nq = random.randint(l_nq, g_nq)
|
||||
search_param = {}
|
||||
for k, v in search_params.items():
|
||||
search_param[k] = random.randint(int(v.split("-")[0]), int(v.split("-")[1]))
|
||||
logger.debug("Query nq: %d, top-k: %d, param: %s" % (nq, top_k, json.dumps(search_param)))
|
||||
vector_query = {"vector": {vec_field_name: {
|
||||
"topk": top_k,
|
||||
"query": query_vectors[:nq],
|
||||
"metric_type": real_metric_type,
|
||||
"params": search_param}
|
||||
}}
|
||||
result = milvus_instance.query(vector_query)
|
||||
elif name in ["flush", "compact"]:
|
||||
func = getattr(milvus_instance, name)
|
||||
func()
|
||||
logger.debug(milvus_instance.count())
|
||||
except Exception as e:
|
||||
logger.error(name)
|
||||
logger.error(str(e))
|
||||
raise
|
||||
logger.debug("Loop time: %d" % i)
|
||||
end_mem_usage = milvus_instance.get_mem_info()["memory_used"]
|
||||
end_row_count = milvus_instance.count()
|
||||
metrics = {
|
||||
"during_time": during_time,
|
||||
"start_mem_usage": start_mem_usage,
|
||||
"end_mem_usage": end_mem_usage,
|
||||
"diff_mem": end_mem_usage - start_mem_usage,
|
||||
"row_count_increments": end_row_count - start_row_count
|
||||
}
|
||||
logger.info(metrics)
|
||||
|
||||
elif run_type == "loop_stability":
|
||||
# init data
|
||||
milvus_instance.clean_db()
|
||||
pull_interval = collection["pull_interval"]
|
||||
collection_num = collection["collection_num"]
|
||||
concurrent = collection["concurrent"] if "concurrent" in collection else False
|
||||
concurrent_num = collection_num
|
||||
dimension = collection["dimension"] if "dimension" in collection else 128
|
||||
insert_xb = collection["insert_xb"] if "insert_xb" in collection else 100000
|
||||
index_types = collection["index_types"] if "index_types" in collection else ['ivf_sq8']
|
||||
index_param = {"nlist": 256}
|
||||
collection_names = []
|
||||
milvus_instances_map = {}
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(insert_xb)]
|
||||
ids = [i for i in range(insert_xb)]
|
||||
# initialize and prepare
|
||||
for i in range(collection_num):
|
||||
name = utils.get_unique_name(prefix="collection_%d_" % i)
|
||||
collection_names.append(name)
|
||||
metric_type = random.choice(["l2", "ip"])
|
||||
# default float_vector
|
||||
milvus_instance = MilvusClient(collection_name=name, host=self.host)
|
||||
milvus_instance.create_collection(dimension, other_fields=None)
|
||||
index_type = random.choice(index_types)
|
||||
field_name = utils.get_default_field_name()
|
||||
milvus_instance.create_index(field_name, index_type, metric_type, index_param=index_param)
|
||||
logger.info(milvus_instance.describe_index())
|
||||
insert_vectors = utils.normalize(metric_type, insert_vectors)
|
||||
entities = milvus_instance.generate_entities(insert_vectors, ids)
|
||||
res_ids = milvus_instance.insert(entities, ids=ids)
|
||||
milvus_instance.flush()
|
||||
milvus_instances_map.update({name: milvus_instance})
|
||||
logger.info(milvus_instance.describe_index())
|
||||
|
||||
# loop time unit: min -> s
|
||||
pull_interval_seconds = pull_interval * 60
|
||||
tasks = ["insert_rand", "delete_rand", "query_rand", "flush", "compact"]
|
||||
i = 1
|
||||
while True:
|
||||
logger.info("Loop time: %d" % i)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < pull_interval_seconds:
|
||||
if concurrent:
|
||||
threads = []
|
||||
for name in collection_names:
|
||||
task_name = random.choice(tasks)
|
||||
task_run = getattr(milvus_instances_map[name], task_name)
|
||||
t = threading.Thread(target=task_run, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
# with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_num) as executor:
|
||||
# future_results = {executor.submit(getattr(milvus_instances_map[mp[j][0]], mp[j][1])): j for j in range(concurrent_num)}
|
||||
# for future in concurrent.futures.as_completed(future_results):
|
||||
# future.result()
|
||||
else:
|
||||
tmp_collection_name = random.choice(collection_names)
|
||||
task_name = random.choice(tasks)
|
||||
logger.info(tmp_collection_name)
|
||||
logger.info(task_name)
|
||||
task_run = getattr(milvus_instances_map[tmp_collection_name], task_name)
|
||||
task_run()
|
||||
# new connection
|
||||
# for name in collection_names:
|
||||
# milvus_instance = MilvusClient(collection_name=name, host=self.host)
|
||||
# milvus_instances_map.update({name: milvus_instance})
|
||||
i = i + 1
|
||||
|
||||
elif run_type == "locust_mix_performance":
|
||||
(data_type, collection_size, dimension, metric_type) = parser.collection_parser(
|
||||
collection_name)
|
||||
ni_per = collection["ni_per"]
|
||||
build_index = collection["build_index"]
|
||||
vector_type = self.get_vector_type(data_type)
|
||||
index_field_name = utils.get_default_field_name(vector_type)
|
||||
# drop exists collection
|
||||
if milvus_instance.exists_collection():
|
||||
milvus_instance.drop()
|
||||
time.sleep(10)
|
||||
# create collection
|
||||
other_fields = collection["other_fields"] if "other_fields" in collection else None
|
||||
milvus_instance.create_collection(dimension, data_type=DataType.FLOAT_VECTOR, collection_name=collection_name, other_fields=other_fields)
|
||||
logger.info(milvus_instance.get_info())
|
||||
# insert entities
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(ni_per)]
|
||||
insert_ids = random.sample(list(range(collection_size)), ni_per)
|
||||
insert_vectors = utils.normalize(metric_type, insert_vectors)
|
||||
entities = milvus_instance.generate_entities(insert_vectors, insert_ids, collection_name)
|
||||
milvus_instance.insert(entities, ids=insert_ids)
|
||||
# flush
|
||||
milvus_instance.flush()
|
||||
logger.info(milvus_instance.get_stats())
|
||||
logger.debug("Table row counts: %d" % milvus_instance.count())
|
||||
# create index
|
||||
if build_index is True:
|
||||
index_type = collection["index_type"]
|
||||
index_param = collection["index_param"]
|
||||
logger.debug("Start build index for last file")
|
||||
milvus_instance.create_index(index_field_name, index_type, metric_type, index_param)
|
||||
logger.debug(milvus_instance.describe_index())
|
||||
# locust
|
||||
task = collection["tasks"]
|
||||
task_file = utils.get_unique_name()
|
||||
task_file_script = task_file + '.py'
|
||||
task_file_csv = task_file + '_stats.csv'
|
||||
task_types = task["types"]
|
||||
connection_type = "single"
|
||||
connection_num = task["connection_num"]
|
||||
if connection_num > 1:
|
||||
connection_type = "multi"
|
||||
clients_num = task["clients_num"]
|
||||
hatch_rate = task["hatch_rate"]
|
||||
during_time = task["during_time"]
|
||||
def_strs = ""
|
||||
# define def str
|
||||
for task_type in task_types:
|
||||
type = task_type["type"]
|
||||
weight = task_type["weight"]
|
||||
if type == "flush":
|
||||
def_str = """
|
||||
@task(%d)
|
||||
def flush(self):
|
||||
client = get_client(collection_name)
|
||||
client.flush(collection_name=collection_name)
|
||||
""" % weight
|
||||
if type == "compact":
|
||||
def_str = """
|
||||
@task(%d)
|
||||
def compact(self):
|
||||
client = get_client(collection_name)
|
||||
client.compact(collection_name)
|
||||
""" % weight
|
||||
if type == "query":
|
||||
def_str = """
|
||||
@task(%d)
|
||||
def query(self):
|
||||
client = get_client(collection_name)
|
||||
params = %s
|
||||
X = [[random.random() for i in range(dim)] for i in range(params["nq"])]
|
||||
vector_query = {"vector": {"%s": {
|
||||
"topk": params["top_k"],
|
||||
"query": X,
|
||||
"metric_type": "%s",
|
||||
"params": params["search_param"]}}}
|
||||
client.query(vector_query, filter_query=params["filters"], collection_name=collection_name)
|
||||
""" % (weight, task_type["params"], index_field_name, utils.metric_type_trans(metric_type))
|
||||
if type == "insert":
|
||||
def_str = """
|
||||
@task(%d)
|
||||
def insert(self):
|
||||
client = get_client(collection_name)
|
||||
params = %s
|
||||
insert_ids = random.sample(list(range(100000)), params["nb"])
|
||||
insert_vectors = [[random.random() for _ in range(dim)] for _ in range(params["nb"])]
|
||||
insert_vectors = utils.normalize("l2", insert_vectors)
|
||||
entities = generate_entities(insert_vectors, insert_ids)
|
||||
client.insert(entities,ids=insert_ids, collection_name=collection_name)
|
||||
""" % (weight, task_type["params"])
|
||||
if type == "delete":
|
||||
def_str = """
|
||||
@task(%d)
|
||||
def delete(self):
|
||||
client = get_client(collection_name)
|
||||
ids = [random.randint(1, 1000000) for i in range(1)]
|
||||
client.delete(ids, collection_name)
|
||||
""" % weight
|
||||
def_strs += def_str
|
||||
print(def_strs)
|
||||
# define locust code str
|
||||
code_str = """
|
||||
import random
|
||||
import json
|
||||
from locust import User, task, between
|
||||
from locust_task import MilvusTask
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
|
||||
host = '%s'
|
||||
port = %s
|
||||
collection_name = '%s'
|
||||
dim = %s
|
||||
connection_type = '%s'
|
||||
m = MilvusClient(host=host, port=port)
|
||||
|
||||
|
||||
def get_client(collection_name):
|
||||
if connection_type == 'single':
|
||||
return MilvusTask(m=m)
|
||||
elif connection_type == 'multi':
|
||||
return MilvusTask(connection_type='multi', host=host, port=port, collection_name=collection_name)
|
||||
|
||||
|
||||
def generate_entities(vectors, ids):
|
||||
return m.generate_entities(vectors, ids, collection_name)
|
||||
|
||||
|
||||
class MixTask(User):
|
||||
wait_time = between(0.001, 0.002)
|
||||
%s
|
||||
""" % (self.host, self.port, collection_name, dimension, connection_type, def_strs)
|
||||
with open(task_file_script, "w+") as fd:
|
||||
fd.write(code_str)
|
||||
locust_cmd = "locust -f %s --headless --csv=%s -u %d -r %d -t %s" % (
|
||||
task_file_script,
|
||||
task_file,
|
||||
clients_num,
|
||||
hatch_rate,
|
||||
during_time)
|
||||
logger.info(locust_cmd)
|
||||
try:
|
||||
res = os.system(locust_cmd)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
return
|
||||
|
||||
# . retrieve and collect test statistics
|
||||
metric = None
|
||||
with open(task_file_csv, newline='') as fd:
|
||||
dr = csv.DictReader(fd)
|
||||
for row in dr:
|
||||
if row["Name"] != "Aggregated":
|
||||
continue
|
||||
metric = row
|
||||
logger.info(metric)
|
||||
|
||||
else:
|
||||
raise Exception("Run type not defined")
|
||||
logger.debug("All test finished")
|
30
tests/benchmark/locust_file.py
Normal file
30
tests/benchmark/locust_file.py
Normal file
@ -0,0 +1,30 @@
|
||||
|
||||
import random

from locust import HttpUser, task, between


# Target collection and the HTTP search payload reused for every request.
collection_name = "random_1m_2048_512_ip_sq8"
headers = {'Content-Type': "application/json"}
url = '/collections/%s/vectors' % collection_name
top_k = 2
nq = 1
dim = 512
# One batch of nq random query vectors, generated once at import time.
vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
data = {
    "search": {
        "topk": top_k,
        "vectors": vectors,
        "params": {
            "nprobe": 1
        }
    }
}


class MyUser(HttpUser):
    """Locust user that hammers the HTTP search endpoint with a fixed query."""
    wait_time = between(0, 0.1)
    host = "http://192.168.1.112:19122"

    @task
    def search(self):
        # PUT the canned search payload; 2s timeout per request.
        response = self.client.put(url=url, json=data, headers=headers, timeout=2)
        print(response)
|
33
tests/benchmark/locust_flush_task.py
Normal file
33
tests/benchmark/locust_flush_task.py
Normal file
@ -0,0 +1,33 @@
|
||||
import random
from locust import User, task, between
from locust_task import MilvusTask
from client import MilvusClient
from milvus import DataType

# Connection / collection settings for this scenario.
connection_type = "single"
host = "192.168.1.6"
port = 19530
collection_name = "create_collection_CZkkwJgo"
dim = 128
nb = 50000
# Recreate the collection from scratch and pre-generate one entity batch
# that every flush task reuses.
m = MilvusClient(host=host, port=port, collection_name=collection_name)
m.clean_db()
m.create_collection(dim, data_type=DataType.FLOAT_VECTOR, auto_id=True, other_fields=None)
vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
entities = m.generate_entities(vectors)


class FlushTask(User):
    """Inserts a canned batch then flushes, measuring both via MilvusTask."""
    wait_time = between(0.001, 0.002)
    client = MilvusTask(m=m) if connection_type == "single" else MilvusTask(
        host=host, port=port, collection_name=collection_name)

    # def insert(self):
    #     self.client.insert(entities)

    @task(1)
    def flush(self):
        self.client.insert(entities)
        self.client.flush(collection_name)
|
36
tests/benchmark/locust_get_entity_task.py
Normal file
36
tests/benchmark/locust_get_entity_task.py
Normal file
@ -0,0 +1,36 @@
|
||||
import logging
import random
from locust import User, task, between
from locust_task import MilvusTask
from client import MilvusClient
from milvus import DataType

# Connection / collection settings; the collection is assumed to exist.
connection_type = "single"
host = "192.168.1.6"
port = 19530
collection_name = "sift_10m_100000_128_l2"
dim = 128
m = MilvusClient(host=host, port=port, collection_name=collection_name)
# m.clean_db()
# m.create_collection(dim, data_type=DataType.FLOAT_VECTOR, auto_id=True, other_fields=None)
nb = 6000
# vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
# entities = m.generate_entities(vectors)
ids = [i for i in range(nb)]


class GetEntityTask(User):
    """Repeatedly fetches entity id 0, timing each call via MilvusTask."""
    wait_time = between(0.001, 0.002)
    client = MilvusTask(m=m) if connection_type == "single" else MilvusTask(
        host=host, port=port, collection_name=collection_name)

    # def insert(self):
    #     self.client.insert(entities)

    @task(1)
    def get_entity_by_id(self):
        # num = random.randint(100, 200)
        # get_ids = random.sample(ids, num)
        self.client.get_entities([0])
        # logging.getLogger().info(len(get_res))
|
33
tests/benchmark/locust_insert_task.py
Normal file
33
tests/benchmark/locust_insert_task.py
Normal file
@ -0,0 +1,33 @@
|
||||
import random
from locust import User, task, between
from locust_task import MilvusTask
from client import MilvusClient
from milvus import DataType

# Connection / collection settings for the insert workload.
connection_type = "single"
host = "192.168.1.6"
port = 19530
collection_name = "create_collection_hello"
dim = 128
nb = 50000
m = MilvusClient(host=host, port=port, collection_name=collection_name)
# m.clean_db()
m.create_collection(dim, data_type=DataType.FLOAT_VECTOR, auto_id=True, other_fields=None)
# One batch of nb random vectors reused for every insert.
vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
entities = m.generate_entities(vectors)


# NOTE(review): class name says "FlushTask" but the only task is insert —
# looks copy-pasted from locust_flush_task.py; name kept for compatibility.
class FlushTask(User):
    wait_time = between(0.001, 0.002)
    client = MilvusTask(m=m) if connection_type == "single" else MilvusTask(
        host=host, port=port, collection_name=collection_name)

    @task(1)
    def insert(self):
        self.client.insert(entities)

    # @task(1)
    # def create_partition(self):
    #     tag = 'tag_'.join(random.choice(string.ascii_letters) for _ in range(8))
    #     self.client.create_partition(tag, collection_name)
|
46
tests/benchmark/locust_search_task.py
Normal file
46
tests/benchmark/locust_search_task.py
Normal file
@ -0,0 +1,46 @@
|
||||
import random
from client import MilvusClient
from locust_task import MilvusTask
from locust import User, task, between

# Connection / collection settings for the search workload.
connection_type = "single"
host = "172.16.50.9"
port = 19530
collection_name = "sift_1m_2000000_128_l2_2"
m = MilvusClient(host=host, port=port, collection_name=collection_name)
dim = 128
top_k = 5
nq = 1
X = [[random.random() for i in range(dim)] for i in range(nq)]
search_params = {"nprobe": 16}
# NOTE(review): this module-level query template is shadowed by the locals
# built inside QueryTask.query() below and appears to be unused.
vector_query = {"vector": {'float_vector': {
    "topk": top_k,
    "query": X,
    "params": search_params,
    'metric_type': 'L2'}}}
# m.clean_db()


class QueryTask(User):
    """Issues top-10 vector searches, one fresh MilvusTask per call."""
    wait_time = between(0.001, 0.002)

    def preload(self):
        # NOTE(review): self.client is never assigned on this class, so
        # calling preload() would raise AttributeError — confirm intent.
        self.client.preload_collection()

    @task(10)
    def query(self):
        if connection_type == "single":
            client = MilvusTask(m=m, connection_type=connection_type)
        elif connection_type == "multi":
            client = MilvusTask(host, port, collection_name, connection_type=connection_type)
        top_k = 10
        search_param = {"nprobe": 16}
        X = [[random.random() for i in range(dim)]]
        vector_query = {"vector": {"float_vector": {
            "topk": top_k,
            "query": X,
            "metric_type": "L2",
            "params": search_param}
        }}
        filter_query = None
        client.query(vector_query, filter_query=filter_query, collection_name=collection_name)
|
37
tests/benchmark/locust_task.py
Normal file
37
tests/benchmark/locust_task.py
Normal file
@ -0,0 +1,37 @@
|
||||
import time
|
||||
import pdb
|
||||
import random
|
||||
import logging
|
||||
from locust import User, events
|
||||
from client import MilvusClient
|
||||
|
||||
|
||||
class MilvusTask(object):
    """Proxy around MilvusClient that reports each call's latency to locust.

    Every attribute access returns a wrapper around the corresponding
    MilvusClient method; the wrapper times the call and fires a locust
    request_success / request_failure event. Exceptions are swallowed after
    being reported (the locust convention), in which case the wrapper
    returns None.
    """

    def __init__(self, *args, **kwargs):
        self.request_type = "grpc"
        connection_type = kwargs.get("connection_type")
        if connection_type == "single":
            # Reuse a shared, already-connected client instance.
            self.m = kwargs.get("m")
        elif connection_type == "multi":
            # Open a dedicated connection for this task instance.
            host = kwargs.get("host")
            port = kwargs.get("port")
            collection_name = kwargs.get("collection_name")
            self.m = MilvusClient(host=host, port=port, collection_name=collection_name)
        # logging.getLogger().error(id(self.m))

    def __getattr__(self, name):
        func = getattr(self.m, name)

        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                total_time = int((time.time() - start_time) * 1000)
                events.request_success.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            response_length=0)
                # BUGFIX: the original discarded the client's return value,
                # so e.g. query()/count() results were lost to callers.
                return result
            except Exception as e:
                total_time = int((time.time() - start_time) * 1000)
                events.request_failure.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            exception=e, response_length=0)

        return wrapper
|
45
tests/benchmark/locust_tasks.py
Normal file
45
tests/benchmark/locust_tasks.py
Normal file
@ -0,0 +1,45 @@
|
||||
import random
import time
import logging
from locust import TaskSet, task

# Shared single-vector batch used by the (undecorated) insert helper.
dim = 128
X = [[random.random() for _ in range(dim)] for _ in range(1)]


class Tasks(TaskSet):
    """Operations locust can schedule against the shared Milvus client.

    locust_user.locust_executor attaches a weighted subset of these methods
    to MyUser.tasks at runtime.
    """

    @task
    def query(self):
        top_k = 10
        search_param = {"nprobe": 16}
        X = [[random.random() for i in range(dim)]]
        vector_query = {"vector": {"float_vector": {
            "topk": top_k,
            "query": X,
            "metric_type": "L2",
            "params": search_param}
        }}
        filter_query = None
        self.client.query(vector_query, filter_query=filter_query, log=False)

    @task
    def flush(self):
        self.client.flush(log=False)

    @task
    def get(self):
        self.client.get()

    @task
    def delete(self):
        self.client.delete([random.randint(1, 1000000)], log=False)

    # Not decorated with @task: only runs if referenced explicitly.
    def insert(self):
        ids = [random.randint(1, 10000000)]
        entities = self.client.generate_entities(X, ids)
        self.client.insert(entities, ids, log=False)

    @task
    def insert_rand(self):
        self.client.insert_rand(log=False)
|
18
tests/benchmark/locust_test.py
Normal file
18
tests/benchmark/locust_test.py
Normal file
@ -0,0 +1,18 @@
|
||||
from locust_user import locust_executor
from client import MilvusClient
from milvus import DataType


if __name__ == "__main__":
    # Target server / collection for this manual load-test run.
    connection_type = "single"
    host = "192.168.1.239"
    # host = "172.16.50.15"
    port = 19530
    collection_name = "sift_1m_2000000_128_l2_2"
    run_params = {"tasks": {"insert_rand": 5, "query": 10, "flush": 2}, "clients_num": 10, "spawn_rate": 2, "during_time": 3600}
    dim = 128
    # Recreate the target collection from scratch before driving load.
    m = MilvusClient(host=host, port=port, collection_name=collection_name)
    m.clean_db()
    m.create_collection(dim, data_type=DataType.FLOAT_VECTOR, auto_id=False, other_fields=None)

    locust_executor(host, port, collection_name, run_params=run_params)
|
70
tests/benchmark/locust_user.py
Normal file
70
tests/benchmark/locust_user.py
Normal file
@ -0,0 +1,70 @@
|
||||
import logging
import random
import pdb
import gevent
import gevent.monkey
# Must patch before any socket/thread use by the client library.
gevent.monkey.patch_all()

from locust import User, between, events, stats
from locust.env import Environment
import locust.stats
from locust.stats import stats_printer, print_stats

# Print aggregated stats every 30s instead of the locust default.
locust.stats.CONSOLE_STATS_INTERVAL_SEC = 30
from locust.log import setup_logging, greenlet_exception_logger

from locust_tasks import Tasks
from client import MilvusClient
from locust_task import MilvusTask

logger = logging.getLogger("__locust__")


class MyUser(User):
    """Locust user; its weighted task map is attached by locust_executor()."""
    # task_set = None
    wait_time = between(0.001, 0.002)
|
||||
|
||||
|
||||
def locust_executor(host, port, collection_name, connection_type="single", run_params=None):
    """Run a headless locust load test and return a summary dict.

    run_params must provide "tasks" (operation-name -> weight), "clients_num",
    "spawn_rate" and "during_time" (seconds). Returns rps, fail ratio and
    response-time figures from the aggregated locust stats.
    """
    m = MilvusClient(host=host, port=port, collection_name=collection_name)
    # Build the weighted task map from the requested operation names.
    MyUser.tasks = {}
    tasks = run_params["tasks"]
    for op, weight in tasks.items():
        # SECURITY/IDIOM: getattr instead of eval() on externally supplied
        # operation names; behavior is identical for valid Tasks attributes.
        MyUser.tasks.update({getattr(Tasks, op): weight})
    logger.error(MyUser.tasks)
    # MyUser.tasks = {Tasks.query: 1, Tasks.flush: 1}
    MyUser.client = MilvusTask(host=host, port=port, collection_name=collection_name, connection_type=connection_type, m=m)
    env = Environment(events=events, user_classes=[MyUser])
    runner = env.create_local_runner()
    # setup logging
    # setup_logging("WARNING", "/dev/null")
    setup_logging("WARNING", "/dev/null")
    greenlet_exception_logger(logger=logger)
    gevent.spawn(stats_printer(env.stats))
    # env.create_web_ui("127.0.0.1", 8089)
    # gevent.spawn(stats_printer(env.stats), env, "test", full_history=True)
    # events.init.fire(environment=env, runner=runner)
    clients_num = run_params["clients_num"]
    spawn_rate = run_params["spawn_rate"]
    during_time = run_params["during_time"]
    runner.start(clients_num, spawn_rate=spawn_rate)
    # Stop the run after during_time seconds, then wait for it to finish.
    gevent.spawn_later(during_time, lambda: runner.quit())
    runner.greenlet.join()
    print_stats(env.stats)
    result = {
        "rps": round(env.stats.total.current_rps, 1),
        "fail_ratio": env.stats.total.fail_ratio,
        "max_response_time": round(env.stats.total.max_response_time, 1),
        # NOTE(review): key says "min" but the value is the average response
        # time — kept as-is for compatibility; confirm which was intended.
        "min_response_time": round(env.stats.total.avg_response_time, 1)
    }
    runner.stop()
    return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Quick smoke run: 1 client, query+flush for 3 seconds.
    connection_type = "single"
    host = "192.168.1.112"
    port = 19530
    collection_name = "sift_1m_2000000_128_l2_2"
    run_params = {"tasks": {"query": 1, "flush": 1}, "clients_num": 1, "spawn_rate": 1, "during_time": 3}
    locust_executor(host, port, collection_name, run_params=run_params)
|
199
tests/benchmark/main.py
Normal file
199
tests/benchmark/main.py
Normal file
@ -0,0 +1,199 @@
|
||||
import os
import sys
import time
from datetime import datetime
import pdb
import argparse
import logging
import traceback
from multiprocessing import Process
from queue import Queue
from logging import handlers
from yaml import full_load, dump
from local_runner import LocalRunner
from docker_runner import DockerRunner
import parser

DEFAULT_IMAGE = "milvusdb/milvus:latest"
LOG_FOLDER = "logs"
NAMESPACE = "milvus"
LOG_PATH = "/test/milvus/benchmark/logs/"
BRANCH = "0.11.1"

# Benchmark logger: DEBUG and above to a per-branch dated file,
# INFO and above mirrored to the console, identical message format.
logger = logging.getLogger('milvus_benchmark')
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(LOG_PATH + 'benchmark-{}-{:%Y-%m-%d}.log'.format(BRANCH, datetime.now()))
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
for handler in (fh, ch):
    handler.setFormatter(formatter)
    logger.addHandler(handler)
|
||||
|
||||
def positive_int(s):
    """argparse type-converter: parse s as a strictly positive integer.

    Raises argparse.ArgumentTypeError for non-integer strings and for
    values below 1.
    """
    try:
        value = int(s)
    except ValueError:
        value = None
    # value is None (unparsable) or 0 or negative -> reject.
    if not value or value < 1:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    return value
|
||||
|
||||
|
||||
def get_image_tag(image_version, image_type):
    """Compose the docker image tag for a version/type pair.

    e.g. ("0.11.1", "cpu") -> "0.11.1-cpu-centos7-release".
    """
    return "{}-{}-centos7-release".format(image_version, image_type)
    # return "%s-%s-centos7-release" % ("0.7.1", image_type)
    # return "%s-%s-centos7-release" % ("PR-2780", image_type)
|
||||
|
||||
|
||||
def queue_worker(queue):
    """Drain suite jobs from `queue`, running each against a K8s deployment.

    Each job carries a suite YAML path plus the target server/image info.
    The runner environment is always cleaned up, even when a run fails.
    """
    from k8s_runner import K8sRunner
    while not queue.empty():
        job = queue.get()
        suite = job["suite"]
        server_host = job["server_host"]
        deploy_mode = job["deploy_mode"]
        image_type = job["image_type"]
        image_tag = job["image_tag"]

        with open(suite) as f:
            suite_dict = full_load(f)
            f.close()
        logger.debug(suite_dict)

        run_type, run_params = parser.operations_parser(suite_dict)
        for collection in run_params["collections"]:
            # Optional per-collection milvus/server config overrides.
            milvus_config = collection.get("milvus")
            server_config = collection.get("server")
            logger.debug(milvus_config)
            logger.debug(server_config)
            runner = K8sRunner()
            if runner.init_env(milvus_config, server_config, server_host, deploy_mode, image_type, image_tag):
                logger.debug("Start run tests")
                try:
                    runner.run(run_type, collection)
                except Exception as e:
                    logger.error(str(e))
                    logger.error(traceback.format_exc())
                finally:
                    # Give the deployment a moment to settle before teardown.
                    time.sleep(60)
                    runner.clean_up()
            else:
                logger.error("Runner init failed")
    if server_host:
        logger.debug("All task finished in queue: %s" % server_host)
|
||||
|
||||
|
||||
def main():
    """Entry point: dispatch to scheduler-driven helm mode or local mode."""
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # helm mode with scheduler
    arg_parser.add_argument(
        "--image-version",
        default="",
        help="image version")
    arg_parser.add_argument(
        "--schedule-conf",
        metavar='FILE',
        default='',
        help="load test schedule from FILE")
    arg_parser.add_argument(
        "--deploy-mode",
        default='',
        help="single node or multi nodes")

    # local mode
    arg_parser.add_argument(
        '--local',
        action='store_true',
        help='use local milvus server')
    arg_parser.add_argument(
        '--host',
        help='server host ip param for local mode',
        default='127.0.0.1')
    arg_parser.add_argument(
        '--port',
        help='server port param for local mode',
        default='19530')
    arg_parser.add_argument(
        '--suite',
        metavar='FILE',
        help='load test suite from FILE',
        default='')

    args = arg_parser.parse_args()

    if args.schedule_conf:
        # Helm mode: fan suites out to one worker process per server.
        if args.local:
            raise Exception("Helm mode with scheduler and other mode are incompatible")
        if not args.image_version:
            raise Exception("Image version not given")
        image_version = args.image_version
        deploy_mode = args.deploy_mode
        with open(args.schedule_conf) as f:
            schedule_config = full_load(f)
            f.close()
        queues = []
        # server_names = set()
        server_names = []
        for item in schedule_config:
            server_host = item.get("server", "")
            suite_params = item["suite_params"]
            server_names.append(server_host)
            q = Queue()
            for suite_param in suite_params:
                suite = "suites/" + suite_param["suite"]
                image_type = suite_param["image_type"]
                image_tag = get_image_tag(image_version, image_type)
                q.put({
                    "suite": suite,
                    "server_host": server_host,
                    "deploy_mode": deploy_mode,
                    "image_tag": image_tag,
                    "image_type": image_type
                })
            queues.append(q)
        logger.error(queues)
        thread_num = len(server_names)
        processes = []

        # One worker process per server queue, staggered by 10s.
        for i in range(thread_num):
            x = Process(target=queue_worker, args=(queues[i], ))
            processes.append(x)
            x.start()
            time.sleep(10)
        for x in processes:
            x.join()

        # queue_worker(queues[0])

    elif args.local:
        # Local mode: run a single-collection suite against a local server.
        host = args.host
        port = args.port
        suite = args.suite
        with open(suite) as f:
            suite_dict = full_load(f)
            f.close()
        logger.debug(suite_dict)
        run_type, run_params = parser.operations_parser(suite_dict)
        collections = run_params["collections"]
        if len(collections) > 1:
            raise Exception("Multi collections not supported in Local Mode")
        collection = collections[0]
        runner = LocalRunner(host, port)
        logger.info("Start run local mode test, test type: %s" % run_type)
        runner.run(run_type, collection)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    main()
|
42
tests/benchmark/mix_task.py
Normal file
42
tests/benchmark/mix_task.py
Normal file
@ -0,0 +1,42 @@
|
||||
import random
from locust import User, task, between
from locust_task import MilvusTask
from client import MilvusClient

# Connection / collection settings for the mixed workload.
connection_type = "single"
host = "192.168.1.29"
port = 19530
collection_name = "sift_128_euclidean"
dim = 128
m = MilvusClient(host=host, port=port, collection_name=collection_name)


class MixTask(User):
    """Mixed workload: weighted queries, inserts and flushes."""
    wait_time = between(0.001, 0.002)
    print("in query task")
    client = MilvusTask(m=m) if connection_type == "single" else MilvusTask(
        host=host, port=port, collection_name=collection_name)

    @task(30)
    def query(self):
        # Single random query vector, top-10.
        top_k = 10
        X = [[random.random() for i in range(dim)] for i in range(1)]
        search_param = {"nprobe": 16}
        self.client.query(X, top_k, search_param)

    @task(10)
    def insert(self):
        # One random vector with a random id in a high range.
        id = random.randint(10000000, 10000000000)
        X = [[random.random() for i in range(dim)] for i in range(1)]
        self.client.insert(X, ids=[id])

    @task(1)
    def flush(self):
        self.client.flush()

    # @task(5)
    # def delete(self):
    #     self.client.delete([random.randint(1, 1000000)])
|
||||
|
10
tests/benchmark/operation.py
Normal file
10
tests/benchmark/operation.py
Normal file
@ -0,0 +1,10 @@
|
||||
from __future__ import absolute_import
import pdb
import time


class Base(object):
    """Common base for benchmark operation types (currently a placeholder)."""
    pass


class Insert(Base):
    """Marker type for insert operations (currently a placeholder)."""
    pass
|
85
tests/benchmark/parser.py
Normal file
85
tests/benchmark/parser.py
Normal file
@ -0,0 +1,85 @@
|
||||
import pdb
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.parser")
|
||||
|
||||
|
||||
def operations_parser(operations):
    """Return the (run_type, run_params) pair from a suite definition.

    The suite dict is expected to hold a single entry: its key names the
    run type, its value carries the parameters for that run.

    Raises:
        Exception: if the suite defines no operations.
    """
    if not operations:
        raise Exception("No operations in suite defined")
    # Only the first entry is meaningful; suites define exactly one operation.
    # (Replaces the original loop-that-returns-on-first-iteration idiom.)
    run_type, run_params = next(iter(operations.items()))
    logger.debug(run_type)
    return (run_type, run_params)
|
||||
|
||||
|
||||
def collection_parser(collection_name):
    """Split a name like "sift_1m_128_l2" into (data_type, size, dim, metric).

    The size token's trailing unit character scales the count: "m" -> millions,
    "b" -> billions. Any other unit leaves the chopped prefix string unscaled.
    """
    fields = collection_name.split("_")
    data_type = fields[0]
    size_token = fields[1]
    unit = size_token[-1]
    collection_size = size_token[:-1]
    if unit == "m":
        collection_size = int(collection_size) * 1000000
    elif unit == "b":
        collection_size = int(collection_size) * 1000000000
    dimension = int(fields[2])
    metric_type = str(fields[3])
    return (data_type, collection_size, dimension, metric_type)
|
||||
|
||||
|
||||
def parse_ann_collection_name(collection_name):
    """Parse an ANN benchmark name like "sift_128_euclidean".

    Returns (data_type, dimension, metric_type) where the human-readable metric
    is mapped to the engine metric: euclidean -> l2, angular -> ip,
    jaccard/hamming pass through unchanged.

    Raises Exception for an unrecognized metric (previously this fell through
    and crashed with an opaque UnboundLocalError on metric_type).
    """
    tokens = collection_name.split("_")
    data_type = tokens[0]
    dimension = int(tokens[1])
    metric = tokens[2]
    # metric = collection_name.attrs['distance']
    # dimension = len(collection_name["train"][0])
    metric_map = {
        "euclidean": "l2",
        "angular": "ip",
        "jaccard": "jaccard",
        "hamming": "hamming",
    }
    if metric not in metric_map:
        raise Exception("Invalid metric in collection name: %s" % metric)
    return (data_type, dimension, metric_map[metric])
|
||||
|
||||
|
||||
def search_params_parser(param):
    """Extract (top_ks, nqs, nprobes) lists from a search-suite param dict.

    Each key may be absent (a default is used), a single int (wrapped in a
    list), or a list (copied). Invalid types are logged and returned as-is,
    matching the previous lenient behavior.

    Defaults: top_ks=[10], nqs=[10], nprobes=[1].
    """
    # The three keys share identical parsing; one helper replaces the
    # previous triplicated blocks.
    top_ks = _parse_search_values(param, "top_ks", "top-ks", [10])
    nqs = _parse_search_values(param, "nqs", "nqs", [10])
    nprobes = _parse_search_values(param, "nprobes", "nprobes", [1])
    return top_ks, nqs, nprobes


def _parse_search_values(param, key, label, default):
    """Read param[key] as a list of ints; fall back to default when absent."""
    if key not in param:
        return list(default)
    value = param[key]
    if isinstance(value, int):
        return [value]
    if isinstance(value, list):
        return list(value)
    # Preserve legacy behavior: warn but hand the invalid value back.
    logger.warning("Invalid format %s: %s" % (label, str(value)))
    return value
|
12
tests/benchmark/requirements.txt
Normal file
12
tests/benchmark/requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
pymilvus-test>=0.5.0,<0.6.0
|
||||
scipy>=1.3.1
|
||||
scikit-learn>=0.19.1
|
||||
h5py>=2.7.1
|
||||
# influxdb==5.2.2
|
||||
pyyaml>=5.1
|
||||
tableprint==0.8.0
|
||||
ansicolors==1.1.8
|
||||
kubernetes==10.0.1
|
||||
# rq==1.2.0
|
||||
locust>=1.3.2
|
||||
pymongo==3.10.0
|
11
tests/benchmark/results/__init__.py
Normal file
11
tests/benchmark/results/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
|
||||
class Reporter(object):
    """Base class for result reporters; subclasses implement report()."""

    def __init__(self):
        pass

    def report(self, result):
        """Publish one benchmark result (no-op in the base class)."""
        pass
|
||||
|
||||
|
||||
class BaseResult(object):
    """Marker base class for benchmark result objects."""
    pass
|
0
tests/benchmark/results/reporter.py
Normal file
0
tests/benchmark/results/reporter.py
Normal file
369
tests/benchmark/runner.py
Normal file
369
tests/benchmark/runner.py
Normal file
@ -0,0 +1,369 @@
|
||||
import os
|
||||
import threading
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
import grpc
|
||||
from multiprocessing import Process
|
||||
from itertools import product
|
||||
import numpy as np
|
||||
import sklearn.preprocessing
|
||||
from milvus import DataType
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
import parser
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.runner")
|
||||
|
||||
# Row counts per source .npy file, by dataset family.
VECTORS_PER_FILE = 1000000
SIFT_VECTORS_PER_FILE = 100000
BINARY_VECTORS_PER_FILE = 2000000

# Hard upper bound on nq for queries served from the single query file.
MAX_NQ = 10001
FILE_PREFIX = "binary_"

# FOLDER_NAME = 'ann_1000m/source_data'
# Raw-data directories on the benchmark host, by dataset family.
SRC_BINARY_DATA_DIR = '/test/milvus/raw_data/random/'
SIFT_SRC_DATA_DIR = '/test/milvus/raw_data/sift1b/'
DEEP_SRC_DATA_DIR = '/test/milvus/raw_data/deep1b/'
BINARY_SRC_DATA_DIR = '/test/milvus/raw_data/binary/'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = SIFT_SRC_DATA_DIR + 'gnd'

# Parameters for a warm-up query before timed runs.
WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512


# Collection size -> SIFT ground-truth .ivecs file name.
GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}
|
||||
|
||||
|
||||
def gen_file_name(idx, dimension, data_type):
    """Build the path of the idx-th raw .npy file for a dataset family.

    The file name pattern is "binary_<dim>d_<idx:05d>.npy"; the directory is
    chosen by data_type. An unknown data_type yields the bare file name with
    no directory prefix, matching the original fall-through behavior.
    """
    base_name = FILE_PREFIX + str(dimension) + "d_" + ("%05d" % idx) + ".npy"
    dir_by_type = {
        "random": SRC_BINARY_DATA_DIR,
        "sift": SIFT_SRC_DATA_DIR,
        "deep": DEEP_SRC_DATA_DIR,
        "binary": BINARY_SRC_DATA_DIR,
    }
    return dir_by_type.get(data_type, "") + base_name
|
||||
|
||||
|
||||
def get_vectors_from_binary(nq, dimension, data_type):
    """Load the first nq query vectors for a dataset family from disk.

    Only the single per-family query file is used, so nq must not exceed
    MAX_NQ. Returns the vectors as a list of lists.

    Raises Exception for an oversized nq or an unknown data_type (the latter
    previously crashed with an UnboundLocalError on file_name).
    """
    # use the first file, nq should be less than VECTORS_PER_FILE
    if nq > MAX_NQ:
        raise Exception("Over size nq")
    if data_type == "random":
        file_name = SRC_BINARY_DATA_DIR + 'query_%d.npy' % dimension
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR + 'query.npy'
    elif data_type == "deep":
        file_name = DEEP_SRC_DATA_DIR + 'query.npy'
    elif data_type == "binary":
        file_name = BINARY_SRC_DATA_DIR + 'query.npy'
    else:
        raise Exception("Unsupported data_type: %s" % data_type)
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors
|
||||
|
||||
|
||||
class Runner(object):
|
||||
    def __init__(self):
        # Runner keeps no per-instance state; every helper below receives
        # the client handle and parameters explicitly as arguments.
        pass
|
||||
|
||||
def gen_executors(self, operations):
|
||||
l = []
|
||||
for name, operation in operations.items():
|
||||
weight = operation["weight"] if "weight" in operation else 1
|
||||
l.extend([name] * weight)
|
||||
random.shuffle(l)
|
||||
return l
|
||||
|
||||
def get_vector_type(self, data_type):
|
||||
vector_type = ''
|
||||
if data_type in ["random", "sift", "deep", "glove"]:
|
||||
vector_type = DataType.FLOAT_VECTOR
|
||||
elif data_type in ["binary"]:
|
||||
vector_type = DataType.BINARY_VECTOR
|
||||
else:
|
||||
raise Exception("Data type: %s not defined" % data_type)
|
||||
return vector_type
|
||||
|
||||
def get_vector_type_from_metric(self, metric_type):
|
||||
vector_type = ''
|
||||
if metric_type in ["hamming", "jaccard"]:
|
||||
vector_type = DataType.BINARY_VECTOR
|
||||
else:
|
||||
vector_type = DataType.FLOAT_VECTOR
|
||||
return vector_type
|
||||
|
||||
def normalize(self, metric_type, X):
|
||||
if metric_type == "ip":
|
||||
logger.info("Set normalize for metric_type: %s" % metric_type)
|
||||
X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
|
||||
X = X.astype(np.float32)
|
||||
elif metric_type == "l2":
|
||||
X = X.astype(np.float32)
|
||||
elif metric_type in ["jaccard", "hamming", "sub", "super"]:
|
||||
tmp = []
|
||||
for item in X:
|
||||
new_vector = bytes(np.packbits(item, axis=-1).tolist())
|
||||
tmp.append(new_vector)
|
||||
X = tmp
|
||||
return X
|
||||
|
||||
def generate_combinations(self, args):
|
||||
if isinstance(args, list):
|
||||
args = [el if isinstance(el, list) else [el] for el in args]
|
||||
return [list(x) for x in product(*args)]
|
||||
elif isinstance(args, dict):
|
||||
flat = []
|
||||
for k, v in args.items():
|
||||
if isinstance(v, list):
|
||||
flat.append([(k, el) for el in v])
|
||||
else:
|
||||
flat.append([(k, v)])
|
||||
return [dict(x) for x in product(*flat)]
|
||||
else:
|
||||
raise TypeError("No args handling exists for %s" % type(args).__name__)
|
||||
|
||||
    def do_insert(self, milvus, collection_name, data_type, dimension, size, ni):
        '''
        Bulk-insert `size` vectors from on-disk .npy files in batches of `ni`.

        @params:
        mivlus: server connect instance
        dimension: collection dimensionn
        # index_file_size: size trigger file merge
        size: row count of vectors to be insert
        ni: row count of vectors to be insert each time
        # store_id: if store the ids returned by call add_vectors or not
        @return:
        total_time: total time for all insert operation
        qps: vectors added per second
        ni_time: avarage insert operation time
        '''
        bi_res = {}
        total_time = 0.0
        qps = 0.0
        ni_time = 0.0
        # Rows per source file depend on dataset family (and on dimension for
        # the "random" family).
        if data_type == "random":
            if dimension == 512:
                vectors_per_file = VECTORS_PER_FILE
            elif dimension == 4096:
                vectors_per_file = 100000
            elif dimension == 16384:
                vectors_per_file = 10000
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        elif data_type in ["binary"]:
            vectors_per_file = BINARY_VECTORS_PER_FILE
        else:
            raise Exception("data_type: %s not supported" % data_type)
        # size must divide evenly into both whole files and whole batches.
        if size % vectors_per_file or size % ni:
            raise Exception("Not invalid collection size or ni")
        i = 0
        while i < (size // vectors_per_file):
            vectors = []
            if vectors_per_file >= ni:
                # Batch is at most one file: load one file, insert it in
                # ni-sized slices.
                file_name = gen_file_name(i, dimension, data_type)
                # logger.info("Load npy file: %s start" % file_name)
                data = np.load(file_name)
                # logger.info("Load npy file: %s end" % file_name)
                for j in range(vectors_per_file // ni):
                    vectors = data[j*ni:(j+1)*ni].tolist()
                    if vectors:
                        # start insert vectors
                        start_id = i * vectors_per_file + j * ni
                        end_id = start_id + len(vectors)
                        logger.debug("Start id: %s, end id: %s" % (start_id, end_id))
                        ids = [k for k in range(start_id, end_id)]
                        entities = milvus.generate_entities(vectors, ids)
                        ni_start_time = time.time()
                        try:
                            res_ids = milvus.insert(entities, ids=ids)
                        except grpc.RpcError as e:
                            if e.code() == grpc.StatusCode.UNAVAILABLE:
                                logger.debug("Retry insert")
                                # NOTE(review): the retry result is assigned to
                                # a local inside the closure and discarded, and
                                # the original error is re-raised below even
                                # after a "successful" retry -- confirm intent.
                                def retry():
                                    res_ids = milvus.insert(entities, ids=ids)

                                t0 = threading.Thread(target=retry)
                                t0.start()
                                t0.join()
                                logger.debug("Retry successfully")
                            raise e
                        assert ids == res_ids
                        # milvus.flush()
                        logger.debug(milvus.count())
                        ni_end_time = time.time()
                        # Only insert time is accumulated; file loading is not timed.
                        total_time = total_time + ni_end_time - ni_start_time
                i += 1
            else:
                # Batch spans several files: concatenate `loops` files into
                # one ni-sized insert.
                vectors.clear()
                loops = ni // vectors_per_file
                for j in range(loops):
                    file_name = gen_file_name(loops*i+j, dimension, data_type)
                    data = np.load(file_name)
                    vectors.extend(data.tolist())
                if vectors:
                    start_id = i * vectors_per_file
                    end_id = start_id + len(vectors)
                    logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                    ids = [k for k in range(start_id, end_id)]
                    entities = milvus.generate_entities(vectors, ids)
                    ni_start_time = time.time()
                    try:
                        res_ids = milvus.insert(entities, ids=ids)
                    except grpc.RpcError as e:
                        if e.code() == grpc.StatusCode.UNAVAILABLE:
                            logger.debug("Retry insert")
                            # NOTE(review): same discarded-retry concern as above.
                            def retry():
                                res_ids = milvus.insert(entities, ids=ids)

                            t0 = threading.Thread(target=retry)
                            t0.start()
                            t0.join()
                            logger.debug("Retry successfully")
                        raise e

                    assert ids == res_ids
                    # milvus.flush()
                    logger.debug(milvus.count())
                    ni_end_time = time.time()
                    total_time = total_time + ni_end_time - ni_start_time
                i += loops
        qps = round(size / total_time, 2)
        ni_time = round(total_time / (size / ni), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res
|
||||
|
||||
    def do_query(self, milvus, collection_name, vec_field_name, top_ks, nqs, run_count=1, search_param=None, filter_query=None):
        # Run every (nq, top_k) combination `run_count` times and collect the
        # minimum latency of each combination. Returns a nested list indexed
        # [nq][top_k], values rounded to 2 decimals.
        bi_res = []
        # NOTE(review): collection_parser (defined in parser.py above) returns
        # a 4-tuple, but 5 values are unpacked here -- confirm which signature
        # is current; as written this raises ValueError at runtime.
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        for nq in nqs:
            tmp_res = []
            query_vectors = base_query_vectors[0:nq]
            for top_k in top_ks:
                # avg_query_time is never updated below; only the minimum is kept.
                avg_query_time = 0.0
                min_query_time = 0.0
                logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(query_vectors)))
                for i in range(run_count):
                    logger.debug("Start run query, run %d of %s" % (i+1, run_count))
                    start_time = time.time()
                    vector_query = {"vector": {vec_field_name: {
                        "topk": top_k,
                        "query": query_vectors,
                        "metric_type": utils.metric_type_trans(metric_type),
                        "params": search_param}
                    }}
                    query_res = milvus.query(vector_query, filter_query=filter_query)
                    interval_time = time.time() - start_time
                    # Track the fastest of the run_count repetitions.
                    if (i == 0) or (min_query_time > interval_time):
                        min_query_time = interval_time
                logger.info("Min query time: %.2f" % min_query_time)
                tmp_res.append(round(min_query_time, 2))
            bi_res.append(tmp_res)
        return bi_res
|
||||
|
||||
def do_query_qps(self, milvus, query_vectors, top_k, search_param):
|
||||
start_time = time.time()
|
||||
result = milvus.query(query_vectors, top_k, search_param)
|
||||
end_time = time.time()
|
||||
return end_time - start_time
|
||||
|
||||
    def do_query_ids(self, milvus, collection_name, vec_field_name, top_k, nq, search_param=None, filter_query=None):
        # Run a single search and return only the result ids (used by recall checks).
        # NOTE(review): collection_parser returns a 4-tuple above; 5 values are
        # unpacked here -- confirm the intended signature.
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        query_vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(query_vectors)))
        vector_query = {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}
        query_res = milvus.query(vector_query, filter_query=filter_query)
        result_ids = milvus.get_ids(query_res)
        return result_ids
|
||||
|
||||
    def do_query_acc(self, milvus, collection_name, top_k, nq, id_store_name, search_param=None):
        # Run one search and persist the returned ids to a tab-separated file
        # (one line per query vector) for later accuracy comparison.
        # NOTE(review): same 4-vs-5 tuple unpack concern as do_query above, and
        # the search_param argument is ignored (query is called with
        # search_param=None) -- confirm both are intended.
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=None)
        # if file existed, cover it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id)+'\t')
                fd.write('\n')
|
||||
|
||||
# compute and print accuracy
|
||||
def compute_accuracy(self, flat_file_name, index_file_name):
|
||||
flat_id_list = []; index_id_list = []
|
||||
logger.info("Loading flat id file: %s" % flat_file_name)
|
||||
with open(flat_file_name, 'r') as flat_id_fd:
|
||||
for line in flat_id_fd:
|
||||
tmp_list = line.strip("\n").strip().split("\t")
|
||||
flat_id_list.append(tmp_list)
|
||||
logger.info("Loading index id file: %s" % index_file_name)
|
||||
with open(index_file_name) as index_id_fd:
|
||||
for line in index_id_fd:
|
||||
tmp_list = line.strip("\n").strip().split("\t")
|
||||
index_id_list.append(tmp_list)
|
||||
if len(flat_id_list) != len(index_id_list):
|
||||
raise Exception("Flat index result length: <flat: %s, index: %s> not match, Acc compute exiting ..." % (len(flat_id_list), len(index_id_list)))
|
||||
# get the accuracy
|
||||
return self.get_recall_value(flat_id_list, index_id_list)
|
||||
|
||||
def get_recall_value(self, true_ids, result_ids):
|
||||
"""
|
||||
Use the intersection length
|
||||
"""
|
||||
sum_radio = 0.0
|
||||
for index, item in enumerate(result_ids):
|
||||
# tmp = set(item).intersection(set(flat_id_list[index]))
|
||||
tmp = set(true_ids[index]).intersection(set(item))
|
||||
sum_radio = sum_radio + len(tmp) / len(item)
|
||||
# logger.debug(sum_radio)
|
||||
return round(sum_radio / len(result_ids), 3)
|
||||
|
||||
"""
|
||||
Implementation based on:
|
||||
https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
|
||||
"""
|
||||
def get_groundtruth_ids(self, collection_size):
|
||||
fname = GROUNDTRUTH_MAP[str(collection_size)]
|
||||
fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
|
||||
a = np.fromfile(fname, dtype='int32')
|
||||
d = a[0]
|
||||
true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
|
||||
return true_ids
|
||||
|
||||
def get_fields(self, milvus, collection_name):
|
||||
fields = []
|
||||
info = milvus.get_info(collection_name)
|
||||
for item in info["fields"]:
|
||||
fields.append(item["name"])
|
||||
return fields
|
||||
|
||||
# def get_filter_query(self, filter_query):
|
||||
# for filter in filter_query:
|
11
tests/benchmark/runners/__init__.py
Normal file
11
tests/benchmark/runners/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
|
||||
|
||||
class BaseRunner(object):
    """Common interface for benchmark runners; subclasses override the hooks."""

    def __init__(self):
        pass

    def set_up(self):
        """Prepare resources before a run (no-op by default)."""
        pass

    def tear_down(self):
        """Release resources after a run (no-op by default)."""
        pass
|
75
tests/benchmark/runners/locust_runner.py
Normal file
75
tests/benchmark/runners/locust_runner.py
Normal file
@ -0,0 +1,75 @@
|
||||
import time
|
||||
import random
|
||||
from locust import Locust, TaskSet, events, task, between
|
||||
from client import MilvusClient
|
||||
from . import BasicRunner
|
||||
|
||||
|
||||
# Fixed search parameters shared by the locust tasks below.
dim = 128  # query vector dimensionality
top_k = 10
# One random query vector, regenerated once at import time.
X = [[random.random() for i in range(dim)] for i in range(1)]
search_param = {"nprobe": 16}
|
||||
|
||||
|
||||
class MilvusTask(object):
    """Wraps a MilvusClient and reports per-call timings to locust.

    type "single" builds a client from args["host"]/args["port"]/
    args["collection_name"]; type "multi" reuses the handle in args["m"].
    """

    def __init__(self, type="single", args=None):
        # The original signature placed the non-default `args` after the
        # defaulted `type`, which is a SyntaxError; give args a default too.
        self.type = type
        self.m = None
        if args is None:
            args = {}
        if type == "single":
            self.m = MilvusClient(host=args["host"], port=args["port"], collection_name=args["collection_name"])
        elif type == "multi":
            self.m = MilvusClient(host=args["m"])

    def query(self, *args, **kwargs):
        """Run a search and fire a locust success/failure event with its latency."""
        name = "milvus_search"
        request_type = "grpc"
        start_time = time.time()
        try:
            # result = self.m.getattr(*args, **kwargs)
            status, result = self.m.query(*args, **kwargs)
        except Exception as e:
            total_time = int((time.time() - start_time) * 1000)
            events.request_failure.fire(request_type=request_type, name=name, response_time=total_time, exception=e, response_length=0)
        else:
            total_time = int((time.time() - start_time) * 1000)
            if not status.OK:
                # The original fired this event with `exception=e`, but `e` is
                # unbound in the else branch; report the bad status instead.
                events.request_failure.fire(request_type=request_type, name=name, response_time=total_time, exception=Exception("query status not OK: %s" % str(status)), response_length=0)
            else:
                events.request_success.fire(request_type=request_type, name=name, response_time=total_time, response_length=0)
        # In this example, I've hardcoded response_length=0. If we would want the response length to be
        # reported correctly in the statistics, we would probably need to hook in at a lower level
||||
|
||||
|
||||
class MilvusLocust(Locust):
    # Base locust user that builds one MilvusTask client from class attributes
    # (host/port/collection_name are expected on subclasses such as Query).
    def __init__(self, *args, **kwargs):
        super(MilvusLocust, self).__init__(*args, **kwargs)
        # NOTE(review): MilvusTask.__init__ (defined above) does not accept
        # three positional connection arguments -- confirm the intended
        # call signature before use.
        self.client = MilvusTask(self.host, self.port, self.collection_name)
|
||||
|
||||
|
||||
class Query(MilvusLocust):
    # Connection settings consumed by MilvusLocust.__init__.
    host = "192.168.1.183"
    port = 19530
    collection_name = "sift_128_euclidean"
    # m = MilvusClient(host=host, port=port, collection_name=collection_name)
    # Wait 1-2 ms between consecutive tasks.
    wait_time = between(0.001, 0.002)

    class task_set(TaskSet):
        @task
        def query(self):
            # Uses the module-level X / top_k / search_param defined above.
            self.client.query(X, top_k, search_param)
|
||||
|
||||
|
||||
class LocustRunner(BasicRunner):
    """Only one client, not support M/S mode"""
    # NOTE(review): this file imports BasicRunner from the package, but
    # runners/__init__.py defines BaseRunner -- confirm the intended base
    # class name; as written the import fails.
    def __init__(self, args):
        # Start client with params including client number && last time && hatch rate ...
        pass

    def set_up(self):
        # helm install locust client
        pass

    def tear_down(self):
        # helm uninstall
        pass
|
65
tests/benchmark/scheduler/010_data.json
Normal file
65
tests/benchmark/scheduler/010_data.json
Normal file
@ -0,0 +1,65 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "080_gpu_accuracy.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_search_stability.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_accuracy_ann.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "080_gpu_search.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_search.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_gpu_build.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_accuracy.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "locust_search.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "apollo",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "cpu_accuracy_ann.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_build.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_insert_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "add_flush_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
62
tests/benchmark/scheduler/011_data.json
Normal file
62
tests/benchmark/scheduler/011_data.json
Normal file
@ -0,0 +1,62 @@
|
||||
[
|
||||
{
|
||||
"server": "idc-sh002",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_accuracy_ann.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_gpu_accuracy_ann.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "idc-sh003",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "locust_mix.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "idc-sh004",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_insert_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_gpu_accuracy.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_gpu_build.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "idc-sh005",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_gpu_search.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_cpu_search.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_cpu_accuracy.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_locust_search.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/011_data_acc_debug.json
Normal file
11
tests/benchmark/scheduler/011_data_acc_debug.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "apollo",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_accuracy_ann.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/011_data_gpu_build.json
Normal file
11
tests/benchmark/scheduler/011_data_gpu_build.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "eros",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_gpu_build_sift10m.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/011_data_insert.json
Normal file
11
tests/benchmark/scheduler/011_data_insert.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "eros",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_insert_data.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/011_data_search_debug.json
Normal file
11
tests/benchmark/scheduler/011_data_search_debug.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_gpu_search_debug.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
15
tests/benchmark/scheduler/011_delete.json
Normal file
15
tests/benchmark/scheduler/011_delete.json
Normal file
@ -0,0 +1,15 @@
|
||||
[
|
||||
{
|
||||
"server": "apollo",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_insert_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "011_delete_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
65
tests/benchmark/scheduler/080_data.json
Normal file
65
tests/benchmark/scheduler/080_data.json
Normal file
@ -0,0 +1,65 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "080_gpu_accuracy.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_search_stability.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_accuracy_ann.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "080_gpu_search.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_search.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_gpu_build.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_accuracy.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "locust_search.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "apollo",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "cpu_accuracy_ann.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_cpu_build.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "080_insert_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "add_flush_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
15
tests/benchmark/scheduler/acc.json
Normal file
15
tests/benchmark/scheduler/acc.json
Normal file
@ -0,0 +1,15 @@
|
||||
[
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "crud_add.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_accuracy_sift1m.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/build.json
Normal file
11
tests/benchmark/scheduler/build.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "eros",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_gpu_build_sift1b.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/clean.json
Normal file
11
tests/benchmark/scheduler/clean.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "clean.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/debug.json
Normal file
11
tests/benchmark/scheduler/debug.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "idc-sh002",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_search_sift10m.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
53
tests/benchmark/scheduler/default_config.json
Normal file
53
tests/benchmark/scheduler/default_config.json
Normal file
@ -0,0 +1,53 @@
|
||||
[
|
||||
{
|
||||
"server": "apollo",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "cpu_accuracy_ann.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "gpu_search_performance.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "cpu_search_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "insert_performance.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_accuracy.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"server": "eros",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "gpu_accuracy_ann.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_search_stability.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "gpu_build_performance.yaml",
|
||||
"image_type": "gpu"
|
||||
},
|
||||
{
|
||||
"suite": "cpu_build_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/file_size.json
Normal file
11
tests/benchmark/scheduler/file_size.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "file_size.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/filter.json
Normal file
11
tests/benchmark/scheduler/filter.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "poseidon",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_search_dsl.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/idc.json
Normal file
11
tests/benchmark/scheduler/idc.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "idc-sh004",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_search_debug.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/insert.json
Normal file
11
tests/benchmark/scheduler/insert.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "idc-sh002",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_insert_data.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/jaccard.json
Normal file
11
tests/benchmark/scheduler/jaccard.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_search_binary.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/locust.json
Normal file
11
tests/benchmark/scheduler/locust.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "idc-sh002",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "locust_cluster_search.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
10
tests/benchmark/scheduler/locust_mix_debug.json
Normal file
10
tests/benchmark/scheduler/locust_mix_debug.json
Normal file
@ -0,0 +1,10 @@
|
||||
[
|
||||
{
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "locust_mix.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
10
tests/benchmark/scheduler/loop.json
Normal file
10
tests/benchmark/scheduler/loop.json
Normal file
@ -0,0 +1,10 @@
|
||||
[
|
||||
{
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "loop_stability.yaml",
|
||||
"image_type": "gpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
11
tests/benchmark/scheduler/search.json
Normal file
11
tests/benchmark/scheduler/search.json
Normal file
@ -0,0 +1,11 @@
|
||||
[
|
||||
{
|
||||
"server": "athena",
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "011_cpu_search_sift1b.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
18
tests/benchmark/scheduler/shards.json
Normal file
18
tests/benchmark/scheduler/shards.json
Normal file
@ -0,0 +1,18 @@
|
||||
[
|
||||
{
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "shards_insert_performance.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "shards_ann_debug.yaml",
|
||||
"image_type": "cpu"
|
||||
},
|
||||
{
|
||||
"suite": "shards_loop_stability.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
10
tests/benchmark/scheduler/shards_ann.json
Normal file
10
tests/benchmark/scheduler/shards_ann.json
Normal file
@ -0,0 +1,10 @@
|
||||
[
|
||||
{
|
||||
"suite_params": [
|
||||
{
|
||||
"suite": "shards_ann_debug.yaml",
|
||||
"image_type": "cpu"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user