milvus/internal/querynodev2/tasks/concurrent_safe_scheduler.go

package tasks

import (
    "fmt"
    "sync"

    "go.uber.org/atomic"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus/internal/querynodev2/collector"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/metrics"
    "github.com/milvus-io/milvus/pkg/util/conc"
    "github.com/milvus-io/milvus/pkg/util/lifetime"
    "github.com/milvus-io/milvus/pkg/util/metricsinfo"
    "github.com/milvus-io/milvus/pkg/util/paramtable"
)
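
// maxReceiveChanBatchConsumeNum bounds how many pending add requests are drained
// from receiveChan in a single batch; see consumeRecvChan.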
const (
    maxReceiveChanBatchConsumeNum = 100
)
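
// Typical lifecycle of a scheduler (illustrative sketch; "policy" stands for any
// schedulePolicy implementation provided elsewhere in this package):
//
//	sched := newScheduler(policy)
//	sched.Start()          // spawn the schedule and exec loops
//	err := sched.Add(task) // blocks until the task is accepted or rejected
//	...
//	sched.Stop()           // drain accepted tasks and stop the workers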

// newScheduler creates a scheduler with the given schedule policy.
func newScheduler(policy schedulePolicy) Scheduler {
    maxReadConcurrency := paramtable.Get().QueryNodeCfg.MaxReadConcurrency.GetAsInt()
    maxReceiveChanSize := paramtable.Get().QueryNodeCfg.MaxReceiveChanSize.GetAsInt()
    log.Info("query node use concurrent safe scheduler", zap.Int("max_concurrency", maxReadConcurrency))

    return &scheduler{
        policy:           policy,
        receiveChan:      make(chan addTaskReq, maxReceiveChanSize),
        execChan:         make(chan Task),
        pool:             conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
        gpuPool:          conc.NewPool[any](paramtable.Get().QueryNodeCfg.MaxGpuReadConcurrency.GetAsInt(), conc.WithPreAlloc(true)),
        schedulerCounter: schedulerCounter{},
        lifetime:         lifetime.NewLifetime(lifetime.Initializing),
    }
}
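
// addTaskReq couples a task with the channel on which the schedule loop reports
// the enqueue result back to the Add caller.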
type addTaskReq struct {
    task Task
    err  chan<- error
}

// scheduler is a general, concurrency-safe scheduler implementation that wraps a schedule policy.
type scheduler struct {
    policy      schedulePolicy
    receiveChan chan addTaskReq
    execChan    chan Task
    pool        *conc.Pool[any]
    gpuPool     *conc.Pool[any]

    // wg is the wait group for the internal worker goroutines.
    wg sync.WaitGroup
    // lifetime controls the scheduler state and makes sure all accepted requests are processed.
    lifetime lifetime.Lifetime[lifetime.State]

    schedulerCounter
}

// Add adds a new task into the scheduler.
// An error will be returned if the scheduler reaches some limit.
func (s *scheduler) Add(task Task) (err error) {
    if err := s.lifetime.Add(lifetime.IsWorking); err != nil {
        return err
    }
    defer s.lifetime.Done()

    errCh := make(chan error, 1)

    // TODO: the add operation should be fast, is the UnsolveLen metric unnecessary?
    metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()

    // Start a new in-queue span and send the task to the add chan.
    s.receiveChan <- addTaskReq{
        task: task,
        err:  errCh,
    }
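
    // Block until the schedule loop has pushed the task into the policy
    // (or rejected it) and reported the result on errCh.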
    err = <-errCh
    metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
    return
}

// Start schedules the owned tasks asynchronously and continuously.
// Start should only be called once.
func (s *scheduler) Start() {
    s.wg.Add(2)

    // Start the background task-executing loop.
    go s.exec()
    // Begin to schedule tasks.
    go s.schedule()

    s.lifetime.SetState(lifetime.Working)
}

func (s *scheduler) Stop() {
    s.lifetime.SetState(lifetime.Stopped)
    // Wait until all accepted Add calls are done.
    s.lifetime.Wait()
    // Closing receiveChan starts the stopping process of `schedule`.
    close(s.receiveChan)
    // Wait for the workers to quit.
    s.wg.Wait()
}

// schedule schedules the owned tasks asynchronously and continuously.
func (s *scheduler) schedule() {
    defer s.wg.Done()
    var task Task
    for {
        s.setupReadyLenMetric()

        var execChan chan Task
        nq := int64(0)
        task, nq, execChan = s.setupExecListener(task)
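        // If there is no ready task, execChan stays nil; sending on a nil channel
        // blocks forever, so the send case below is effectively disabled and the
        // loop only waits for new add requests.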
        select {
        case req, ok := <-s.receiveChan:
            if !ok {
                log.Info("receiveChan closed, processing remaining request")
                // Drain the tasks maintained by the policy.
                for task != nil {
                    execChan <- task
                    s.updateWaitingTaskCounter(-1, -nq)
                    task = s.produceExecChan()
                }
                log.Info("all tasks put into execChan, schedule worker exit")
                close(s.execChan)
                return
            }
            // Receive an add operation request and return the process result.
            // And consume the recv chan as much as possible.
            s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum)
        case execChan <- task:
            // Task sent, drop the ownership of the sent task.
            // Update the waiting task counter.
            s.updateWaitingTaskCounter(-1, -nq)
            // And produce new tasks into execChan as much as possible.
            task = s.produceExecChan()
        }
    }
}

// consumeRecvChan consumes the recv chan as much as possible.
func (s *scheduler) consumeRecvChan(req addTaskReq, limit int) {
    // Check the dynamic wait task limit.
    maxWaitTaskNum := paramtable.Get().QueryNodeCfg.MaxUnsolvedQueueSize.GetAsInt64()
    if !s.handleAddTaskRequest(req, maxWaitTaskNum) {
        return
    }

    // Consume the add chan until reaching the batch operation limit.
    for i := 1; i < limit; i++ {
        select {
        case req, ok := <-s.receiveChan:
            if !ok {
                return
            }
            if !s.handleAddTaskRequest(req, maxWaitTaskNum) {
                return
            }
        default:
            return
        }
    }
}

// handleAddTaskRequest handles an add task request.
// It returns true if processing can continue.
func (s *scheduler) handleAddTaskRequest(req addTaskReq, maxWaitTaskNum int64) bool {
    if err := req.task.Canceled(); err != nil {
        log.Warn("task canceled before enqueue", zap.Error(err))
        req.err <- err
    } else {
        // Push the task into the policy to schedule and update the counter of the ready queue.
        nq := req.task.NQ()
        newTaskAdded, err := s.policy.Push(req.task)
        if err == nil {
            s.updateWaitingTaskCounter(int64(newTaskAdded), nq)
        }
        req.err <- err
    }
    // Continue processing if the queue hasn't reached the max limit.
    return s.GetWaitingTaskTotal() < maxWaitTaskNum
}

// produceExecChan produces tasks from the scheduler into the exec chan as much as possible.
func (s *scheduler) produceExecChan() Task {
    var task Task
    for {
        var execChan chan Task
        nq := int64(0)
        task, nq, execChan = s.setupExecListener(task)

        select {
        case execChan <- task:
            // Update the waiting task counter.
            s.updateWaitingTaskCounter(-1, -nq)
            // Task sent, drop the ownership of the sent task.
            task = nil
        default:
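            // execChan cannot accept a task right now (or there is no ready task);
            // hand the un-sent task (possibly nil) back to the caller, which keeps
            // ownership of it.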
            return task
        }
    }
}

// exec executes the ready tasks in the background continuously.
func (s *scheduler) exec() {
    defer s.wg.Done()
    log.Info("start execute loop")

    for {
        t, ok := <-s.execChan
        if !ok {
            log.Info("scheduler execChan closed, worker exit")
            return
        }

        // Skip this task if it is canceled.
        if err := t.Canceled(); err != nil {
            log.Warn("task canceled before executing", zap.Error(err))
            t.Done(err)
            continue
        }
        if err := t.PreExecute(); err != nil {
            log.Warn("failed to pre-execute task", zap.Error(err))
            t.Done(err)
            continue
        }
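
        // Hand the task off to the matching bounded pool; Execute runs
        // asynchronously and Done reports the result back to the waiter.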
        s.getPool(t).Submit(func() (any, error) {
            // Update the concurrency metric and counter.
            metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
            collector.Counter.Inc(metricsinfo.ExecuteQueueType)

            err := t.Execute()

            // Update all metrics after the task finished.
            metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
            collector.Counter.Dec(metricsinfo.ExecuteQueueType)
            // Notify task done.
            t.Done(err)
            return nil, err
        })
    }
}
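
// getPool returns the dedicated GPU pool for GPU-index tasks and the general read pool otherwise.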
func (s *scheduler) getPool(t Task) *conc.Pool[any] {
    if t.IsGpuIndex() {
        return s.gpuPool
    }
    return s.pool
}

// setupExecListener sets up the execChan and the next task to run.
func (s *scheduler) setupExecListener(lastWaitingTask Task) (Task, int64, chan Task) {
    var execChan chan Task
    nq := int64(0)

    if lastWaitingTask == nil {
        // No task is waiting to be sent to execChan, schedule a new one from the queue.
        lastWaitingTask = s.policy.Pop()
    }
    if lastWaitingTask != nil {
        // Try to send the task to execChan if there is a task ready to run.
        execChan = s.execChan
        nq = lastWaitingTask.NQ()
    }
    return lastWaitingTask, nq, execChan
}

// setupReadyLenMetric updates the read task ready-length metric.
func (s *scheduler) setupReadyLenMetric() {
    waitingTaskCount := s.GetWaitingTaskTotal()

    // Update the ReadyQueue counter for quota.
    collector.Counter.Set(metricsinfo.ReadyQueueType, waitingTaskCount)
    // Record the waiting task length of the policy as the ready task metric.
    metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(waitingTaskCount))
}

// schedulerCounter is the scheduler counter implementation, concurrency safe.
type schedulerCounter struct {
    waitingTaskTotal   atomic.Int64
    waitingTaskTotalNQ atomic.Int64
}

// GetWaitingTaskTotal gets the ready task count.
func (s *schedulerCounter) GetWaitingTaskTotal() int64 {
    return s.waitingTaskTotal.Load()
}

// GetWaitingTaskTotalNQ gets the ready task NQ.
func (s *schedulerCounter) GetWaitingTaskTotalNQ() int64 {
    return s.waitingTaskTotalNQ.Load()
}

// updateWaitingTaskCounter updates the waiting task counters for observation.
func (s *schedulerCounter) updateWaitingTaskCounter(num int64, nq int64) {
    s.waitingTaskTotal.Add(num)
    s.waitingTaskTotalNQ.Add(nq)
}