mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 05:18:52 +08:00
7af02fa531
Signed-off-by: yah01 <yang.cen@zilliz.com>
238 lines
4.9 KiB
Go
238 lines
4.9 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package funcutil
|
|
|
|
import (
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// GetFunctionName returns the symbol name of the function value i, as
// reported by runtime.FuncForPC for the value's code pointer (the name is
// package-qualified, e.g. "pkg.Func").
//
// i is expected to hold a func; the pointer is obtained through
// reflect.Value.Pointer, which panics for kinds it does not support.
func GetFunctionName(i interface{}) string {
	entry := reflect.ValueOf(i).Pointer()
	fn := runtime.FuncForPC(entry)
	return fn.Name()
}
|
|
|
|
// Callback signatures for the parallel helpers.
type (
	// TaskFunc is a self-contained unit of work that reports failure via its
	// error result; ProcessTaskParallel runs a list of these.
	TaskFunc func() error

	// ProcessFunc handles the item at position idx; ProcessFuncParallel calls
	// it once for every idx in its range.
	ProcessFunc func(idx int) error

	// DataProcessFunc handles a single data item of unspecified type.
	DataProcessFunc func(data interface{}) error
)
|
|
|
|
// ProcessFuncParallel processes function in parallel.
|
|
//
|
|
// ProcessFuncParallel waits for all goroutines done if no errors occur.
|
|
// If some goroutines return error, ProcessFuncParallel cancels other goroutines as soon as possible and wait
|
|
// for all other goroutines done, and returns the first error occurs.
|
|
// Reference: https://stackoverflow.com/questions/40809504/idiomatic-goroutine-termination-and-error-handling
|
|
func ProcessFuncParallel(total, maxParallel int, f ProcessFunc, fname string) error {
|
|
if maxParallel <= 0 {
|
|
maxParallel = 1
|
|
}
|
|
|
|
t := time.Now()
|
|
defer func() {
|
|
log.Debug(fname, zap.Any("time cost", time.Since(t)))
|
|
}()
|
|
|
|
nPerBatch := (total + maxParallel - 1) / maxParallel
|
|
log.Debug(fname, zap.Any("total", total))
|
|
log.Debug(fname, zap.Any("nPerBatch", nPerBatch))
|
|
|
|
quit := make(chan bool)
|
|
errc := make(chan error)
|
|
done := make(chan error)
|
|
getMin := func(a, b int) int {
|
|
if a < b {
|
|
return a
|
|
}
|
|
return b
|
|
}
|
|
routineNum := 0
|
|
var wg sync.WaitGroup
|
|
for begin := 0; begin < total; begin = begin + nPerBatch {
|
|
j := begin
|
|
wg.Add(1)
|
|
go func(begin int) {
|
|
defer wg.Done()
|
|
|
|
select {
|
|
case <-quit:
|
|
return
|
|
default:
|
|
}
|
|
|
|
err := error(nil)
|
|
|
|
end := getMin(total, begin+nPerBatch)
|
|
for idx := begin; idx < end; idx++ {
|
|
err = f(idx)
|
|
if err != nil {
|
|
log.Error(fname, zap.Error(err), zap.Any("idx", idx))
|
|
break
|
|
}
|
|
}
|
|
|
|
ch := done // send to done channel
|
|
if err != nil {
|
|
ch = errc // send to error channel
|
|
}
|
|
|
|
select {
|
|
case ch <- err:
|
|
return
|
|
case <-quit:
|
|
return
|
|
}
|
|
}(j)
|
|
|
|
routineNum++
|
|
}
|
|
|
|
log.Debug(fname, zap.Any("NumOfGoRoutines", routineNum))
|
|
|
|
if routineNum <= 0 {
|
|
return nil
|
|
}
|
|
|
|
count := 0
|
|
for {
|
|
select {
|
|
case err := <-errc:
|
|
close(quit)
|
|
wg.Wait()
|
|
return err
|
|
case <-done:
|
|
count++
|
|
if count == routineNum {
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ProcessTaskParallel processes tasks in parallel.
|
|
// Similar to ProcessFuncParallel
|
|
func ProcessTaskParallel(maxParallel int, fname string, tasks ...TaskFunc) error {
|
|
// option := parallelProcessOption{}
|
|
// for _, opt := range opts {
|
|
// opt(&option)
|
|
// }
|
|
|
|
if maxParallel <= 0 {
|
|
maxParallel = 1
|
|
}
|
|
|
|
t := time.Now()
|
|
defer func() {
|
|
log.Debug(fname, zap.Any("time cost", time.Since(t)))
|
|
}()
|
|
|
|
total := len(tasks)
|
|
nPerBatch := (total + maxParallel - 1) / maxParallel
|
|
log.Debug(fname, zap.Any("total", total))
|
|
log.Debug(fname, zap.Any("nPerBatch", nPerBatch))
|
|
|
|
quit := make(chan bool)
|
|
errc := make(chan error)
|
|
done := make(chan error)
|
|
getMin := func(a, b int) int {
|
|
if a < b {
|
|
return a
|
|
}
|
|
return b
|
|
}
|
|
routineNum := 0
|
|
var wg sync.WaitGroup
|
|
for begin := 0; begin < total; begin = begin + nPerBatch {
|
|
j := begin
|
|
|
|
// if option.preExecute != nil {
|
|
// err := option.preExecute()
|
|
// if err != nil {
|
|
// close(quit)
|
|
// wg.Wait()
|
|
// return err
|
|
// }
|
|
// }
|
|
|
|
wg.Add(1)
|
|
go func(begin int) {
|
|
defer wg.Done()
|
|
|
|
select {
|
|
case <-quit:
|
|
return
|
|
default:
|
|
}
|
|
|
|
err := error(nil)
|
|
|
|
end := getMin(total, begin+nPerBatch)
|
|
for idx := begin; idx < end; idx++ {
|
|
err = tasks[idx]()
|
|
if err != nil {
|
|
log.Error(fname, zap.Error(err), zap.Any("idx", idx))
|
|
break
|
|
}
|
|
}
|
|
|
|
ch := done // send to done channel
|
|
if err != nil {
|
|
ch = errc // send to error channel
|
|
}
|
|
|
|
select {
|
|
case ch <- err:
|
|
return
|
|
case <-quit:
|
|
return
|
|
}
|
|
}(j)
|
|
// if option.postExecute != nil {
|
|
// option.postExecute()
|
|
// }
|
|
|
|
routineNum++
|
|
}
|
|
|
|
log.Debug(fname, zap.Any("NumOfGoRoutines", routineNum))
|
|
|
|
if routineNum <= 0 {
|
|
return nil
|
|
}
|
|
|
|
count := 0
|
|
for {
|
|
select {
|
|
case err := <-errc:
|
|
close(quit)
|
|
wg.Wait()
|
|
return err
|
|
case <-done:
|
|
count++
|
|
if count == routineNum {
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|