mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Add KNN cgo pool (#23526)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
99a0713b0c
commit
4fe363c4b2
@ -6,7 +6,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
)
|
||||
|
||||
var ioPool *conc.Pool
|
||||
var ioPool *conc.Pool[any]
|
||||
var ioPoolInitOnce sync.Once
|
||||
|
||||
func initIOPool() {
|
||||
@ -15,10 +15,10 @@ func initIOPool() {
|
||||
capacity = 32
|
||||
}
|
||||
// error only happens with negative expiry duration or with negative pre-alloc size.
|
||||
ioPool = conc.NewPool(capacity)
|
||||
ioPool = conc.NewPool[any](capacity)
|
||||
}
|
||||
|
||||
func getOrCreateIOPool() *conc.Pool {
|
||||
func getOrCreateIOPool() *conc.Pool[any] {
|
||||
ioPoolInitOnce.Do(initIOPool)
|
||||
return ioPool
|
||||
}
|
||||
|
52
internal/querynodev2/segments/pool.go
Normal file
52
internal/querynodev2/segments/pool.go
Normal file
@ -0,0 +1,52 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
ants "github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
p atomic.Pointer[conc.Pool[any]]
|
||||
initOnce sync.Once
|
||||
)
|
||||
|
||||
// InitPool initialize
|
||||
func InitPool() {
|
||||
initOnce.Do(func() {
|
||||
pool := conc.NewPool[any](
|
||||
paramtable.Get().QueryNodeCfg.MaxReadConcurrency.GetAsInt(),
|
||||
ants.WithPreAlloc(true),
|
||||
ants.WithDisablePurge(true),
|
||||
)
|
||||
conc.WarmupPool(pool, runtime.LockOSThread)
|
||||
|
||||
p.Store(pool)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPool returns the singleton pool instance.
|
||||
func GetPool() *conc.Pool[any] {
|
||||
InitPool()
|
||||
return p.Load()
|
||||
}
|
@ -226,7 +226,11 @@ func (s *LocalSegment) InsertCount() int64 {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
rowCount := C.GetRowCount(s.ptr)
|
||||
var rowCount C.int64_t
|
||||
GetPool().Submit(func() (any, error) {
|
||||
rowCount = C.GetRowCount(s.ptr)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
return int64(rowCount)
|
||||
}
|
||||
@ -235,7 +239,11 @@ func (s *LocalSegment) RowNum() int64 {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
rowCount := C.GetRealCount(s.ptr)
|
||||
var rowCount C.int64_t
|
||||
GetPool().Submit(func() (any, error) {
|
||||
rowCount = C.GetRealCount(s.ptr)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
return int64(rowCount)
|
||||
}
|
||||
@ -244,7 +252,11 @@ func (s *LocalSegment) MemSize() int64 {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
memoryUsageInBytes := C.GetMemoryUsageInBytes(s.ptr)
|
||||
var memoryUsageInBytes C.int64_t
|
||||
GetPool().Submit(func() (any, error) {
|
||||
memoryUsageInBytes = C.GetMemoryUsageInBytes(s.ptr)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
return int64(memoryUsageInBytes)
|
||||
}
|
||||
@ -345,15 +357,19 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
|
||||
log.Debug("search segment...")
|
||||
|
||||
var searchResult SearchResult
|
||||
tr := timerecord.NewTimeRecorder("cgoSearch")
|
||||
status := C.Search(s.ptr,
|
||||
searchReq.plan.cSearchPlan,
|
||||
searchReq.cPlaceholderGroup,
|
||||
traceCtx,
|
||||
C.uint64_t(searchReq.timestamp),
|
||||
&searchResult.cSearchResult,
|
||||
)
|
||||
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
tr := timerecord.NewTimeRecorder("cgoSearch")
|
||||
status = C.Search(s.ptr,
|
||||
searchReq.plan.cSearchPlan,
|
||||
searchReq.cPlaceholderGroup,
|
||||
traceCtx,
|
||||
C.uint64_t(searchReq.timestamp),
|
||||
&searchResult.cSearchResult,
|
||||
)
|
||||
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return nil, nil
|
||||
}).Await()
|
||||
if err := HandleCStatus(&status, "Search failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -386,17 +402,21 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
|
||||
}
|
||||
|
||||
var retrieveResult RetrieveResult
|
||||
ts := C.uint64_t(plan.Timestamp)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
ts := C.uint64_t(plan.Timestamp)
|
||||
tr := timerecord.NewTimeRecorder("cgoRetrieve")
|
||||
status = C.Retrieve(s.ptr,
|
||||
plan.cRetrievePlan,
|
||||
traceCtx,
|
||||
ts,
|
||||
&retrieveResult.cRetrieveResult,
|
||||
)
|
||||
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
tr := timerecord.NewTimeRecorder("cgoRetrieve")
|
||||
status := C.Retrieve(s.ptr,
|
||||
plan.cRetrievePlan,
|
||||
traceCtx,
|
||||
ts,
|
||||
&retrieveResult.cRetrieveResult,
|
||||
)
|
||||
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
|
||||
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
log.Debug("do retrieve on segment",
|
||||
zap.Int64("msgID", plan.msgID),
|
||||
zap.String("segmentType", s.typ.String()),
|
||||
@ -405,6 +425,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
|
||||
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := new(segcorepb.RetrieveResults)
|
||||
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
|
||||
return nil, err
|
||||
@ -486,7 +507,11 @@ func (s *LocalSegment) preInsert(numOfRecords int) (int64, error) {
|
||||
var offset int64
|
||||
cOffset := (*C.int64_t)(&offset)
|
||||
|
||||
status := C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
if err := HandleCStatus(&status, "PreInsert failed"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -498,7 +523,11 @@ func (s *LocalSegment) preDelete(numOfRecords int) int64 {
|
||||
long int
|
||||
PreDelete(CSegmentInterface c_segment, long int size);
|
||||
*/
|
||||
offset := C.PreDelete(s.ptr, C.int64_t(int64(numOfRecords)))
|
||||
var offset C.int64_t
|
||||
GetPool().Submit(func() (any, error) {
|
||||
offset = C.PreDelete(s.ptr, C.int64_t(int64(numOfRecords)))
|
||||
return nil, nil
|
||||
}).Await()
|
||||
return int64(offset)
|
||||
}
|
||||
|
||||
@ -530,14 +559,19 @@ func (s *LocalSegment) Insert(rowIDs []int64, timestamps []typeutil.Timestamp, r
|
||||
var cEntityIdsPtr = (*C.int64_t)(&(rowIDs)[0])
|
||||
var cTimestampsPtr = (*C.uint64_t)(&(timestamps)[0])
|
||||
|
||||
status := C.Insert(s.ptr,
|
||||
cOffset,
|
||||
cNumOfRows,
|
||||
cEntityIdsPtr,
|
||||
cTimestampsPtr,
|
||||
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
|
||||
(C.uint64_t)(len(insertRecordBlob)),
|
||||
)
|
||||
var status C.CStatus
|
||||
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.Insert(s.ptr,
|
||||
cOffset,
|
||||
cNumOfRows,
|
||||
cEntityIdsPtr,
|
||||
cTimestampsPtr,
|
||||
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
|
||||
(C.uint64_t)(len(insertRecordBlob)),
|
||||
)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
if err := HandleCStatus(&status, "Insert failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -604,13 +638,17 @@ func (s *LocalSegment) Delete(primaryKeys []storage.PrimaryKey, timestamps []typ
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ids: %s", err)
|
||||
}
|
||||
status := C.Delete(s.ptr,
|
||||
cOffset,
|
||||
cSize,
|
||||
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
|
||||
(C.uint64_t)(len(dataBlob)),
|
||||
cTimestampsPtr,
|
||||
)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.Delete(s.ptr,
|
||||
cOffset,
|
||||
cSize,
|
||||
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
|
||||
(C.uint64_t)(len(dataBlob)),
|
||||
cTimestampsPtr,
|
||||
)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
if err := HandleCStatus(&status, "Delete failed"); err != nil {
|
||||
return err
|
||||
@ -667,7 +705,11 @@ func (s *LocalSegment) LoadField(rowCount int64, data *schemapb.FieldData) error
|
||||
mmap_dir_path: mmapDirPath,
|
||||
}
|
||||
|
||||
status := C.LoadFieldData(s.ptr, loadInfo)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.LoadFieldData(s.ptr, loadInfo)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
if err := HandleCStatus(&status, "LoadFieldData failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -739,7 +781,12 @@ func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error {
|
||||
CStatus
|
||||
LoadDeletedRecord(CSegmentInterface c_segment, CLoadDeletedRecordInfo deleted_record_info)
|
||||
*/
|
||||
status := C.LoadDeletedRecord(s.ptr, loadInfo)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.LoadDeletedRecord(s.ptr, loadInfo)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
if err := HandleCStatus(&status, "LoadDeletedRecord failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -783,7 +830,12 @@ func (s *LocalSegment) LoadIndex(bytesIndex [][]byte, indexInfo *querypb.FieldIn
|
||||
zap.Int64("segmentID", s.ID()),
|
||||
)
|
||||
|
||||
status := C.UpdateSealedSegmentIndex(s.ptr, loadIndexInfo.cLoadIndexInfo)
|
||||
var status C.CStatus
|
||||
GetPool().Submit(func() (any, error) {
|
||||
status = C.UpdateSealedSegmentIndex(s.ptr, loadIndexInfo.cLoadIndexInfo)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
|
||||
if err := HandleCStatus(&status, "UpdateSealedSegmentIndex failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ func NewLoader(
|
||||
ioPoolSize = configPoolSize
|
||||
}
|
||||
|
||||
ioPool := conc.NewPool(ioPoolSize, ants.WithPreAlloc(true))
|
||||
ioPool := conc.NewPool[*storage.Blob](ioPoolSize, ants.WithPreAlloc(true))
|
||||
|
||||
log.Info("SegmentLoader created", zap.Int("ioPoolSize", ioPoolSize))
|
||||
|
||||
@ -106,7 +106,7 @@ func NewLoader(
|
||||
type segmentLoader struct {
|
||||
manager CollectionManager
|
||||
cm storage.ChunkManager
|
||||
ioPool *conc.Pool
|
||||
ioPool *conc.Pool[*storage.Blob]
|
||||
}
|
||||
|
||||
var _ Loader = (*segmentLoader)(nil)
|
||||
@ -418,7 +418,7 @@ func (loader *segmentLoader) loadGrowingSegmentFields(ctx context.Context, segme
|
||||
iCodec := storage.InsertCodec{}
|
||||
|
||||
// change all field bin log loading into concurrent
|
||||
loadFutures := make([]*conc.Future[any], 0, len(fieldBinlogs))
|
||||
loadFutures := make([]*conc.Future[*storage.Blob], 0, len(fieldBinlogs))
|
||||
for _, fieldBinlog := range fieldBinlogs {
|
||||
futures := loader.loadFieldBinlogsAsync(ctx, fieldBinlog)
|
||||
loadFutures = append(loadFutures, futures...)
|
||||
@ -431,8 +431,7 @@ func (loader *segmentLoader) loadGrowingSegmentFields(ctx context.Context, segme
|
||||
return future.Err()
|
||||
}
|
||||
|
||||
blob := future.Value().(*storage.Blob)
|
||||
blobs[index] = blob
|
||||
blobs[index] = future.Value()
|
||||
}
|
||||
log.Info("log field binlogs done",
|
||||
zap.Int64("collection", segment.collectionID),
|
||||
@ -507,8 +506,7 @@ func (loader *segmentLoader) loadSealedField(ctx context.Context, segment *Local
|
||||
|
||||
blobs := make([]*storage.Blob, len(futures))
|
||||
for index, future := range futures {
|
||||
blob := future.Value().(*storage.Blob)
|
||||
blobs[index] = blob
|
||||
blobs[index] = future.Value()
|
||||
}
|
||||
|
||||
insertData := storage.InsertData{
|
||||
@ -524,11 +522,11 @@ func (loader *segmentLoader) loadSealedField(ctx context.Context, segment *Local
|
||||
}
|
||||
|
||||
// Load binlogs concurrently into memory from KV storage asyncly
|
||||
func (loader *segmentLoader) loadFieldBinlogsAsync(ctx context.Context, field *datapb.FieldBinlog) []*conc.Future[any] {
|
||||
futures := make([]*conc.Future[any], 0, len(field.Binlogs))
|
||||
func (loader *segmentLoader) loadFieldBinlogsAsync(ctx context.Context, field *datapb.FieldBinlog) []*conc.Future[*storage.Blob] {
|
||||
futures := make([]*conc.Future[*storage.Blob], 0, len(field.Binlogs))
|
||||
for i := range field.Binlogs {
|
||||
path := field.Binlogs[i].GetLogPath()
|
||||
future := loader.ioPool.Submit(func() (interface{}, error) {
|
||||
future := loader.ioPool.Submit(func() (*storage.Blob, error) {
|
||||
binLog, err := loader.cm.Read(ctx, path)
|
||||
if err != nil {
|
||||
log.Warn("failed to load binlog", zap.String("filePath", path), zap.Error(err))
|
||||
@ -571,7 +569,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context, segment *Local
|
||||
func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalSegment, indexInfo *querypb.FieldIndexInfo) error {
|
||||
indexBuffer := make([][]byte, 0, len(indexInfo.IndexFilePaths))
|
||||
filteredPaths := make([]string, 0, len(indexInfo.IndexFilePaths))
|
||||
futures := make([]*conc.Future[any], 0, len(indexInfo.IndexFilePaths))
|
||||
futures := make([]*conc.Future[*storage.Blob], 0, len(indexInfo.IndexFilePaths))
|
||||
indexCodec := storage.NewIndexFileBinlogCodec()
|
||||
|
||||
// TODO, remove the load index info froam
|
||||
@ -616,7 +614,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
|
||||
// load in memory index
|
||||
for _, p := range indexInfo.IndexFilePaths {
|
||||
indexPath := p
|
||||
indexFuture := loader.ioPool.Submit(func() (interface{}, error) {
|
||||
indexFuture := loader.ioPool.Submit(func() (*storage.Blob, error) {
|
||||
log.Info("load index file", zap.String("path", indexPath))
|
||||
data, err := loader.cm.Read(ctx, indexPath)
|
||||
if err != nil {
|
||||
@ -624,7 +622,10 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
|
||||
return nil, err
|
||||
}
|
||||
blobs, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{{Key: path.Base(indexPath), Value: data}})
|
||||
return blobs, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return blobs[0], nil
|
||||
})
|
||||
|
||||
futures = append(futures, indexFuture)
|
||||
@ -636,8 +637,7 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
|
||||
}
|
||||
|
||||
for _, index := range futures {
|
||||
blobs := index.Value().([]*storage.Blob)
|
||||
indexBuffer = append(indexBuffer, blobs[0].Value)
|
||||
indexBuffer = append(indexBuffer, index.Value().GetValue())
|
||||
}
|
||||
|
||||
return segment.LoadIndex(indexBuffer, indexInfo, fieldType)
|
||||
|
@ -59,7 +59,6 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/gc"
|
||||
"github.com/milvus-io/milvus/pkg/util/hardware"
|
||||
"github.com/milvus-io/milvus/pkg/util/lifetime"
|
||||
@ -119,8 +118,9 @@ type QueryNode struct {
|
||||
vectorStorage storage.ChunkManager
|
||||
etcdKV *etcdkv.EtcdKV
|
||||
|
||||
// Pool for search/query
|
||||
taskPool *conc.Pool
|
||||
/*
|
||||
// Pool for search/query
|
||||
knnPool *conc.Pool*/
|
||||
|
||||
// parameter turning hook
|
||||
queryHook queryHook
|
||||
@ -271,7 +271,6 @@ func (node *QueryNode) Init() error {
|
||||
node.etcdKV = etcdkv.NewEtcdKV(node.etcdCli, paramtable.Get().EtcdCfg.MetaRootPath.GetValue())
|
||||
log.Info("queryNode try to connect etcd success", zap.String("MetaRootPath", paramtable.Get().EtcdCfg.MetaRootPath.GetValue()))
|
||||
|
||||
node.taskPool = conc.NewDefaultPool()
|
||||
node.scheduler = tasks.NewScheduler()
|
||||
|
||||
node.clusterManager = cluster.NewWorkerManager(func(nodeID int64) (cluster.Worker, error) {
|
||||
|
@ -25,7 +25,7 @@ type Scheduler struct {
|
||||
queryProcessQueue chan *QueryTask
|
||||
queryWaitQueue chan *QueryTask
|
||||
|
||||
pool *conc.Pool
|
||||
pool *conc.Pool[any]
|
||||
}
|
||||
|
||||
func NewScheduler() *Scheduler {
|
||||
@ -38,7 +38,7 @@ func NewScheduler() *Scheduler {
|
||||
mergedSearchTasks: make(chan *SearchTask, maxReadConcurrency),
|
||||
// queryProcessQueue: make(chan),
|
||||
|
||||
pool: conc.NewPool(maxReadConcurrency, ants.WithPreAlloc(true)),
|
||||
pool: conc.NewPool[any](maxReadConcurrency, ants.WithPreAlloc(true)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,41 +18,43 @@ package conc
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/generic"
|
||||
ants "github.com/panjf2000/ants/v2"
|
||||
)
|
||||
|
||||
// A goroutine pool
|
||||
type Pool struct {
|
||||
type Pool[T any] struct {
|
||||
inner *ants.Pool
|
||||
}
|
||||
|
||||
// NewPool returns a goroutine pool.
|
||||
// cap: the number of workers.
|
||||
// This panic if provide any invalid option.
|
||||
func NewPool(cap int, opts ...ants.Option) *Pool {
|
||||
func NewPool[T any](cap int, opts ...ants.Option) *Pool[T] {
|
||||
pool, err := ants.NewPool(cap, opts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Pool{
|
||||
return &Pool[T]{
|
||||
inner: pool,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDefaultPool returns a pool with cap of the number of logical CPU,
|
||||
// and pre-alloced goroutines.
|
||||
func NewDefaultPool() *Pool {
|
||||
return NewPool(runtime.GOMAXPROCS(0), ants.WithPreAlloc(true))
|
||||
func NewDefaultPool[T any]() *Pool[T] {
|
||||
return NewPool[T](runtime.GOMAXPROCS(0), ants.WithPreAlloc(true))
|
||||
}
|
||||
|
||||
// Submit a task into the pool,
|
||||
// executes it asynchronously.
|
||||
// This will block if the pool has finite workers and no idle worker.
|
||||
// NOTE: As now golang doesn't support the member method being generic, we use Future[any]
|
||||
func (pool *Pool) Submit(method func() (any, error)) *Future[any] {
|
||||
future := newFuture[any]()
|
||||
func (pool *Pool[T]) Submit(method func() (T, error)) *Future[T] {
|
||||
future := newFuture[T]()
|
||||
err := pool.inner.Submit(func() {
|
||||
defer close(future.ch)
|
||||
res, err := method()
|
||||
@ -71,20 +73,38 @@ func (pool *Pool) Submit(method func() (any, error)) *Future[any] {
|
||||
}
|
||||
|
||||
// The number of workers
|
||||
func (pool *Pool) Cap() int {
|
||||
func (pool *Pool[T]) Cap() int {
|
||||
return pool.inner.Cap()
|
||||
}
|
||||
|
||||
// The number of running workers
|
||||
func (pool *Pool) Running() int {
|
||||
func (pool *Pool[T]) Running() int {
|
||||
return pool.inner.Running()
|
||||
}
|
||||
|
||||
// Free returns the number of free workers
|
||||
func (pool *Pool) Free() int {
|
||||
func (pool *Pool[T]) Free() int {
|
||||
return pool.inner.Free()
|
||||
}
|
||||
|
||||
func (pool *Pool) Release() {
|
||||
func (pool *Pool[T]) Release() {
|
||||
pool.inner.Release()
|
||||
}
|
||||
|
||||
// WarmupPool do warm up logic for each goroutine in pool
|
||||
func WarmupPool[T any](pool *Pool[T], warmup func()) {
|
||||
cap := pool.Cap()
|
||||
ch := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(cap)
|
||||
for i := 0; i < cap; i++ {
|
||||
pool.Submit(func() (T, error) {
|
||||
warmup()
|
||||
wg.Done()
|
||||
<-ch
|
||||
return generic.Zero[T](), nil
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}
|
||||
|
@ -24,13 +24,13 @@ import (
|
||||
)
|
||||
|
||||
func TestPool(t *testing.T) {
|
||||
pool := NewDefaultPool()
|
||||
pool := NewDefaultPool[any]()
|
||||
|
||||
taskNum := pool.Cap() * 2
|
||||
futures := make([]*Future[any], 0, taskNum)
|
||||
for i := 0; i < taskNum; i++ {
|
||||
res := i
|
||||
future := pool.Submit(func() (interface{}, error) {
|
||||
future := pool.Submit(func() (any, error) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return res, nil
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user