Add KNN cgo pool (#23526)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2023-04-19 23:56:31 +08:00 committed by GitHub
parent 99a0713b0c
commit 4fe363c4b2
8 changed files with 202 additions and 79 deletions

View File

@ -6,7 +6,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/conc"
)
var ioPool *conc.Pool
var ioPool *conc.Pool[any]
var ioPoolInitOnce sync.Once
func initIOPool() {
@ -15,10 +15,10 @@ func initIOPool() {
capacity = 32
}
// error only happens with negative expiry duration or with negative pre-alloc size.
ioPool = conc.NewPool(capacity)
ioPool = conc.NewPool[any](capacity)
}
func getOrCreateIOPool() *conc.Pool {
func getOrCreateIOPool() *conc.Pool[any] {
ioPoolInitOnce.Do(initIOPool)
return ioPool
}
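Not part of the diff: with the pool now generic over its result type, callers of getOrCreateIOPool keep the same any-typed futures they had before. A minimal usage sketch, with the task body left as a placeholder.

    // Illustrative only: submit a task to the shared IO pool.
    pool := getOrCreateIOPool()
    future := pool.Submit(func() (any, error) {
        // IO-bound work would go here.
        return nil, nil
    })
    future.Await() // block until the task finishes
    if err := future.Err(); err != nil {
        // handle the task error
    }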

View File

@ -0,0 +1,52 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
import (
"runtime"
"sync"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/paramtable"
ants "github.com/panjf2000/ants/v2"
"go.uber.org/atomic"
)
var (
p atomic.Pointer[conc.Pool[any]]
initOnce sync.Once
)
// InitPool initializes the pool exactly once.
func InitPool() {
initOnce.Do(func() {
pool := conc.NewPool[any](
paramtable.Get().QueryNodeCfg.MaxReadConcurrency.GetAsInt(),
ants.WithPreAlloc(true),
ants.WithDisablePurge(true),
)
conc.WarmupPool(pool, runtime.LockOSThread)
p.Store(pool)
})
}
// GetPool returns the singleton pool instance.
func GetPool() *conc.Pool[any] {
InitPool()
return p.Load()
}
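Not part of the diff: a minimal sketch of how query-path code is expected to call the new singleton. GetPool lazily runs InitPool, so callers need no explicit setup; the closure body here is a placeholder rather than a real segcore call.

    // Illustrative only: run a cgo-bound task on the shared KNN pool.
    fut := GetPool().Submit(func() (any, error) {
        // A real caller would invoke a C.* segcore function here; the
        // workers were warmed up with runtime.LockOSThread, so the call
        // runs on a bounded, pre-pinned set of OS threads.
        return nil, nil
    })
    fut.Await()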

View File

@ -226,7 +226,11 @@ func (s *LocalSegment) InsertCount() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
rowCount := C.GetRowCount(s.ptr)
var rowCount C.int64_t
GetPool().Submit(func() (any, error) {
rowCount = C.GetRowCount(s.ptr)
return nil, nil
}).Await()
return int64(rowCount)
}
@ -235,7 +239,11 @@ func (s *LocalSegment) RowNum() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
rowCount := C.GetRealCount(s.ptr)
var rowCount C.int64_t
GetPool().Submit(func() (any, error) {
rowCount = C.GetRealCount(s.ptr)
return nil, nil
}).Await()
return int64(rowCount)
}
@ -244,7 +252,11 @@ func (s *LocalSegment) MemSize() int64 {
s.mut.RLock()
defer s.mut.RUnlock()
memoryUsageInBytes := C.GetMemoryUsageInBytes(s.ptr)
var memoryUsageInBytes C.int64_t
GetPool().Submit(func() (any, error) {
memoryUsageInBytes = C.GetMemoryUsageInBytes(s.ptr)
return nil, nil
}).Await()
return int64(memoryUsageInBytes)
}
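The three hunks above, and most of those that follow, repeat the same shape: declare the result outside the closure, submit the single cgo call to the KNN pool, and Await the future. Not part of the commit, only an illustration of that shape: a hypothetical helper (runOnKNNPool is an invented name) that the call sites could share.

    // runOnKNNPool is a hypothetical helper showing the pattern used
    // throughout this file: run fn on the shared cgo pool and wait for it.
    func runOnKNNPool(fn func()) {
        GetPool().Submit(func() (any, error) {
            fn()
            return nil, nil
        }).Await()
    }

    // With it, the InsertCount body above could read:
    //   var rowCount C.int64_t
    //   runOnKNNPool(func() { rowCount = C.GetRowCount(s.ptr) })
    //   return int64(rowCount)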
@ -345,15 +357,19 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
log.Debug("search segment...")
var searchResult SearchResult
tr := timerecord.NewTimeRecorder("cgoSearch")
status := C.Search(s.ptr,
searchReq.plan.cSearchPlan,
searchReq.cPlaceholderGroup,
traceCtx,
C.uint64_t(searchReq.timestamp),
&searchResult.cSearchResult,
)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
var status C.CStatus
GetPool().Submit(func() (any, error) {
tr := timerecord.NewTimeRecorder("cgoSearch")
status = C.Search(s.ptr,
searchReq.plan.cSearchPlan,
searchReq.cPlaceholderGroup,
traceCtx,
C.uint64_t(searchReq.timestamp),
&searchResult.cSearchResult,
)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Search failed"); err != nil {
return nil, err
}
@ -386,17 +402,21 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
}
var retrieveResult RetrieveResult
ts := C.uint64_t(plan.Timestamp)
var status C.CStatus
GetPool().Submit(func() (any, error) {
ts := C.uint64_t(plan.Timestamp)
tr := timerecord.NewTimeRecorder("cgoRetrieve")
status = C.Retrieve(s.ptr,
plan.cRetrievePlan,
traceCtx,
ts,
&retrieveResult.cRetrieveResult,
)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
return nil, nil
}).Await()
tr := timerecord.NewTimeRecorder("cgoRetrieve")
status := C.Retrieve(s.ptr,
plan.cRetrievePlan,
traceCtx,
ts,
&retrieveResult.cRetrieveResult,
)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("do retrieve on segment",
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.typ.String()),
@ -405,6 +425,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
return nil, err
}
result := new(segcorepb.RetrieveResults)
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
return nil, err
@ -486,7 +507,11 @@ func (s *LocalSegment) preInsert(numOfRecords int) (int64, error) {
var offset int64
cOffset := (*C.int64_t)(&offset)
status := C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "PreInsert failed"); err != nil {
return 0, err
}
@ -498,7 +523,11 @@ func (s *LocalSegment) preDelete(numOfRecords int) int64 {
long int
PreDelete(CSegmentInterface c_segment, long int size);
*/
offset := C.PreDelete(s.ptr, C.int64_t(int64(numOfRecords)))
var offset C.int64_t
GetPool().Submit(func() (any, error) {
offset = C.PreDelete(s.ptr, C.int64_t(int64(numOfRecords)))
return nil, nil
}).Await()
return int64(offset)
}
@ -530,14 +559,19 @@ func (s *LocalSegment) Insert(rowIDs []int64, timestamps []typeutil.Timestamp, r
var cEntityIdsPtr = (*C.int64_t)(&(rowIDs)[0])
var cTimestampsPtr = (*C.uint64_t)(&(timestamps)[0])
status := C.Insert(s.ptr,
cOffset,
cNumOfRows,
cEntityIdsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)),
)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.Insert(s.ptr,
cOffset,
cNumOfRows,
cEntityIdsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)),
)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Insert failed"); err != nil {
return err
}
@ -604,13 +638,17 @@ func (s *LocalSegment) Delete(primaryKeys []storage.PrimaryKey, timestamps []typ
if err != nil {
return fmt.Errorf("failed to marshal ids: %s", err)
}
status := C.Delete(s.ptr,
cOffset,
cSize,
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
(C.uint64_t)(len(dataBlob)),
cTimestampsPtr,
)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.Delete(s.ptr,
cOffset,
cSize,
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
(C.uint64_t)(len(dataBlob)),
cTimestampsPtr,
)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "Delete failed"); err != nil {
return err
@ -667,7 +705,11 @@ func (s *LocalSegment) LoadField(rowCount int64, data *schemapb.FieldData) error
mmap_dir_path: mmapDirPath,
}
status := C.LoadFieldData(s.ptr, loadInfo)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.LoadFieldData(s.ptr, loadInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "LoadFieldData failed"); err != nil {
return err
}
@ -739,7 +781,12 @@ func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error {
CStatus
LoadDeletedRecord(CSegmentInterface c_segment, CLoadDeletedRecordInfo deleted_record_info)
*/
status := C.LoadDeletedRecord(s.ptr, loadInfo)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.LoadDeletedRecord(s.ptr, loadInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "LoadDeletedRecord failed"); err != nil {
return err
}
@ -783,7 +830,12 @@ func (s *LocalSegment) LoadIndex(bytesIndex [][]byte, indexInfo *querypb.FieldIn
zap.Int64("segmentID", s.ID()),
)
status := C.UpdateSealedSegmentIndex(s.ptr, loadIndexInfo.cLoadIndexInfo)
var status C.CStatus
GetPool().Submit(func() (any, error) {
status = C.UpdateSealedSegmentIndex(s.ptr, loadIndexInfo.cLoadIndexInfo)
return nil, nil
}).Await()
if err := HandleCStatus(&status, "UpdateSealedSegmentIndex failed"); err != nil {
return err
}
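Not part of the commit, a sketch only: the cgo status could instead be surfaced through the future's error value, avoiding the captured status variable. The commit keeps the captured-status style so the existing HandleCStatus call sites stay unchanged; the fragment below assumes the same LoadIndex context as the hunk above.

    // Sketch of an alternative error flow (not what the commit does).
    future := GetPool().Submit(func() (any, error) {
        status := C.UpdateSealedSegmentIndex(s.ptr, loadIndexInfo.cLoadIndexInfo)
        return nil, HandleCStatus(&status, "UpdateSealedSegmentIndex failed")
    })
    if err := future.Err(); err != nil {
        return err
    }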

View File

@ -89,7 +89,7 @@ func NewLoader(
ioPoolSize = configPoolSize
}
ioPool := conc.NewPool(ioPoolSize, ants.WithPreAlloc(true))
ioPool := conc.NewPool[*storage.Blob](ioPoolSize, ants.WithPreAlloc(true))
log.Info("SegmentLoader created", zap.Int("ioPoolSize", ioPoolSize))
@ -106,7 +106,7 @@ func NewLoader(
type segmentLoader struct {
manager CollectionManager
cm storage.ChunkManager
ioPool *conc.Pool
ioPool *conc.Pool[*storage.Blob]
}
var _ Loader = (*segmentLoader)(nil)
@ -418,7 +418,7 @@ func (loader *segmentLoader) loadGrowingSegmentFields(ctx context.Context, segme
iCodec := storage.InsertCodec{}
// load all field binlogs concurrently
loadFutures := make([]*conc.Future[any], 0, len(fieldBinlogs))
loadFutures := make([]*conc.Future[*storage.Blob], 0, len(fieldBinlogs))
for _, fieldBinlog := range fieldBinlogs {
futures := loader.loadFieldBinlogsAsync(ctx, fieldBinlog)
loadFutures = append(loadFutures, futures...)
@ -431,8 +431,7 @@ func (loader *segmentLoader) loadGrowingSegmentFields(ctx context.Context, segme
return future.Err()
}
blob := future.Value().(*storage.Blob)
blobs[index] = blob
blobs[index] = future.Value()
}
log.Info("log field binlogs done",
zap.Int64("collection", segment.collectionID),
@ -507,8 +506,7 @@ func (loader *segmentLoader) loadSealedField(ctx context.Context, segment *Local
blobs := make([]*storage.Blob, len(futures))
for index, future := range futures {
blob := future.Value().(*storage.Blob)
blobs[index] = blob
blobs[index] = future.Value()
}
insertData := storage.InsertData{
@ -524,11 +522,11 @@ func (loader *segmentLoader) loadSealedField(ctx context.Context, segment *Local
}
// Load binlogs into memory from KV storage concurrently
func (loader *segmentLoader) loadFieldBinlogsAsync(ctx context.Context, field *datapb.FieldBinlog) []*conc.Future[any] {
futures := make([]*conc.Future[any], 0, len(field.Binlogs))
func (loader *segmentLoader) loadFieldBinlogsAsync(ctx context.Context, field *datapb.FieldBinlog) []*conc.Future[*storage.Blob] {
futures := make([]*conc.Future[*storage.Blob], 0, len(field.Binlogs))
for i := range field.Binlogs {
path := field.Binlogs[i].GetLogPath()
future := loader.ioPool.Submit(func() (interface{}, error) {
future := loader.ioPool.Submit(func() (*storage.Blob, error) {
binLog, err := loader.cm.Read(ctx, path)
if err != nil {
log.Warn("failed to load binlog", zap.String("filePath", path), zap.Error(err))
@ -571,7 +569,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context, segment *Local
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo) error {
indexBuffer := make([][]byte, 0, len(indexInfo.IndexFilePaths))
filteredPaths := make([]string, 0, len(indexInfo.IndexFilePaths))
futures := make([]*conc.Future[any], 0, len(indexInfo.IndexFilePaths))
futures := make([]*conc.Future[*storage.Blob], 0, len(indexInfo.IndexFilePaths))
indexCodec := storage.NewIndexFileBinlogCodec()
// TODO, remove the load index info froam
@ -616,7 +614,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
// load in memory index
for _, p := range indexInfo.IndexFilePaths {
indexPath := p
indexFuture := loader.ioPool.Submit(func() (interface{}, error) {
indexFuture := loader.ioPool.Submit(func() (*storage.Blob, error) {
log.Info("load index file", zap.String("path", indexPath))
data, err := loader.cm.Read(ctx, indexPath)
if err != nil {
@ -624,7 +622,10 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
return nil, err
}
blobs, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{{Key: path.Base(indexPath), Value: data}})
return blobs, err
if err != nil {
return nil, err
}
return blobs[0], nil
})
futures = append(futures, indexFuture)
@ -636,8 +637,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
}
for _, index := range futures {
blobs := index.Value().([]*storage.Blob)
indexBuffer = append(indexBuffer, blobs[0].Value)
indexBuffer = append(indexBuffer, index.Value().GetValue())
}
return segment.LoadIndex(indexBuffer, indexInfo, fieldType)
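Not part of the diff: a minimal sketch of the submit/collect flow under the new *storage.Blob typing, showing why the .(*storage.Blob) assertions above could be dropped. The path and key literals are placeholders, not real binlog paths.

    // Illustrative only: a typed future yields *storage.Blob directly.
    future := loader.ioPool.Submit(func() (*storage.Blob, error) {
        data, err := loader.cm.Read(ctx, "placeholder/binlog/path") // hypothetical path
        if err != nil {
            return nil, err
        }
        return &storage.Blob{Key: "placeholder/binlog/path", Value: data}, nil
    })
    blob := future.Value() // no type assertion needed
    _ = blob               // placeholder use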

View File

@ -59,7 +59,6 @@ import (
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/gc"
"github.com/milvus-io/milvus/pkg/util/hardware"
"github.com/milvus-io/milvus/pkg/util/lifetime"
@ -119,8 +118,9 @@ type QueryNode struct {
vectorStorage storage.ChunkManager
etcdKV *etcdkv.EtcdKV
// Pool for search/query
taskPool *conc.Pool
/*
// Pool for search/query
knnPool *conc.Pool*/
// parameter tuning hook
queryHook queryHook
@ -271,7 +271,6 @@ func (node *QueryNode) Init() error {
node.etcdKV = etcdkv.NewEtcdKV(node.etcdCli, paramtable.Get().EtcdCfg.MetaRootPath.GetValue())
log.Info("queryNode try to connect etcd success", zap.String("MetaRootPath", paramtable.Get().EtcdCfg.MetaRootPath.GetValue()))
node.taskPool = conc.NewDefaultPool()
node.scheduler = tasks.NewScheduler()
node.clusterManager = cluster.NewWorkerManager(func(nodeID int64) (cluster.Worker, error) {

View File

@ -25,7 +25,7 @@ type Scheduler struct {
queryProcessQueue chan *QueryTask
queryWaitQueue chan *QueryTask
pool *conc.Pool
pool *conc.Pool[any]
}
func NewScheduler() *Scheduler {
@ -38,7 +38,7 @@ func NewScheduler() *Scheduler {
mergedSearchTasks: make(chan *SearchTask, maxReadConcurrency),
// queryProcessQueue: make(chan),
pool: conc.NewPool(maxReadConcurrency, ants.WithPreAlloc(true)),
pool: conc.NewPool[any](maxReadConcurrency, ants.WithPreAlloc(true)),
}
}
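Not part of the diff: a standalone sketch of constructing a generic pool the way the scheduler now does, assuming only the conc and ants packages already imported in this file; the capacity of 16 is arbitrary.

    // Illustrative only: a generic pool with pre-allocated workers.
    pool := conc.NewPool[any](16, ants.WithPreAlloc(true))
    defer pool.Release()

    fut := pool.Submit(func() (any, error) {
        return "done", nil
    })
    fut.Await()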

View File

@ -18,41 +18,43 @@ package conc
import (
"runtime"
"sync"
"github.com/milvus-io/milvus/pkg/util/generic"
ants "github.com/panjf2000/ants/v2"
)
// A goroutine pool
type Pool struct {
type Pool[T any] struct {
inner *ants.Pool
}
// NewPool returns a goroutine pool.
// cap: the number of workers.
// It panics if given any invalid option.
func NewPool(cap int, opts ...ants.Option) *Pool {
func NewPool[T any](cap int, opts ...ants.Option) *Pool[T] {
pool, err := ants.NewPool(cap, opts...)
if err != nil {
panic(err)
}
return &Pool{
return &Pool[T]{
inner: pool,
}
}
// NewDefaultPool returns a pool whose capacity is the number of logical CPUs,
// with pre-allocated goroutines.
func NewDefaultPool() *Pool {
return NewPool(runtime.GOMAXPROCS(0), ants.WithPreAlloc(true))
func NewDefaultPool[T any]() *Pool[T] {
return NewPool[T](runtime.GOMAXPROCS(0), ants.WithPreAlloc(true))
}
// Submit adds a task to the pool and executes it asynchronously.
// It blocks when the pool has no idle worker.
// NOTE: Go doesn't support generic methods, so the result type is a type parameter on Pool itself.
func (pool *Pool) Submit(method func() (any, error)) *Future[any] {
future := newFuture[any]()
func (pool *Pool[T]) Submit(method func() (T, error)) *Future[T] {
future := newFuture[T]()
err := pool.inner.Submit(func() {
defer close(future.ch)
res, err := method()
@ -71,20 +73,38 @@ func (pool *Pool) Submit(method func() (any, error)) *Future[any] {
}
// The number of workers
func (pool *Pool) Cap() int {
func (pool *Pool[T]) Cap() int {
return pool.inner.Cap()
}
// The number of running workers
func (pool *Pool) Running() int {
func (pool *Pool[T]) Running() int {
return pool.inner.Running()
}
// Free returns the number of free workers
func (pool *Pool) Free() int {
func (pool *Pool[T]) Free() int {
return pool.inner.Free()
}
func (pool *Pool) Release() {
func (pool *Pool[T]) Release() {
pool.inner.Release()
}
// WarmupPool runs the warmup function on each goroutine in the pool
func WarmupPool[T any](pool *Pool[T], warmup func()) {
cap := pool.Cap()
ch := make(chan struct{})
wg := sync.WaitGroup{}
wg.Add(cap)
for i := 0; i < cap; i++ {
pool.Submit(func() (T, error) {
warmup()
wg.Done()
<-ch
return generic.Zero[T](), nil
})
}
wg.Wait()
close(ch)
}
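WarmupPool submits Cap() tasks that each run warmup, signal the wait group, and then park on ch; because every task blocks until the last one has started, each of the pool's workers runs warmup once before any of them is released. Not part of the diff: a short usage sketch mirroring InitPool above.

    // Illustrative only: pin every worker goroutine to an OS thread.
    pool := conc.NewPool[any](8, ants.WithPreAlloc(true), ants.WithDisablePurge(true))
    conc.WarmupPool(pool, runtime.LockOSThread)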

View File

@ -24,13 +24,13 @@ import (
)
func TestPool(t *testing.T) {
pool := NewDefaultPool()
pool := NewDefaultPool[any]()
taskNum := pool.Cap() * 2
futures := make([]*Future[any], 0, taskNum)
for i := 0; i < taskNum; i++ {
res := i
future := pool.Submit(func() (interface{}, error) {
future := pool.Submit(func() (any, error) {
time.Sleep(500 * time.Millisecond)
return res, nil
})