mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 02:48:45 +08:00
2ef6cbbf59
issue: #36672 The expression supports filling elements through templates, which helps to reduce the overhead of parsing the elements. --------- Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
707 lines
19 KiB
Go
707 lines
19 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"go.opentelemetry.io/otel"
|
|
"go.uber.org/atomic"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/allocator"
|
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
"github.com/milvus-io/milvus/internal/util/exprutil"
|
|
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
"github.com/milvus-io/milvus/pkg/log"
|
|
"github.com/milvus-io/milvus/pkg/metrics"
|
|
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
|
)
|
|
|
|
type BaseDeleteTask = msgstream.DeleteMsg
|
|
|
|
type deleteTask struct {
|
|
baseTask
|
|
Condition
|
|
ctx context.Context
|
|
tr *timerecord.TimeRecorder
|
|
|
|
req *milvuspb.DeleteRequest
|
|
|
|
// channel
|
|
chMgr channelsMgr
|
|
chTicker channelsTimeTicker
|
|
pChannels []pChan
|
|
vChannels []vChan
|
|
|
|
idAllocator allocator.Interface
|
|
|
|
// delete info
|
|
primaryKeys *schemapb.IDs
|
|
collectionID UniqueID
|
|
partitionID UniqueID
|
|
partitionKeyMode bool
|
|
|
|
// set by scheduler
|
|
ts Timestamp
|
|
msgID UniqueID
|
|
|
|
// result
|
|
count int64
|
|
allQueryCnt int64
|
|
|
|
sessionTS Timestamp
|
|
}
|
|
|
|
func (dt *deleteTask) TraceCtx() context.Context {
|
|
return dt.ctx
|
|
}
|
|
|
|
func (dt *deleteTask) ID() UniqueID {
|
|
return dt.msgID
|
|
}
|
|
|
|
func (dt *deleteTask) SetID(uid UniqueID) {
|
|
dt.msgID = uid
|
|
}
|
|
|
|
func (dt *deleteTask) Type() commonpb.MsgType {
|
|
return commonpb.MsgType_Delete
|
|
}
|
|
|
|
func (dt *deleteTask) Name() string {
|
|
return DeleteTaskName
|
|
}
|
|
|
|
func (dt *deleteTask) BeginTs() Timestamp {
|
|
return dt.ts
|
|
}
|
|
|
|
func (dt *deleteTask) EndTs() Timestamp {
|
|
return dt.ts
|
|
}
|
|
|
|
func (dt *deleteTask) SetTs(ts Timestamp) {
|
|
dt.ts = ts
|
|
}
|
|
|
|
func (dt *deleteTask) OnEnqueue() error {
|
|
if dt.req.Base == nil {
|
|
dt.req.Base = commonpbutil.NewMsgBase()
|
|
}
|
|
dt.req.Base.MsgType = commonpb.MsgType_Delete
|
|
dt.req.Base.SourceID = paramtable.GetNodeID()
|
|
return nil
|
|
}
|
|
|
|
func (dt *deleteTask) setChannels() error {
|
|
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.req.GetDbName(), dt.req.GetCollectionName())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
channels, err := dt.chMgr.getChannels(collID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dt.pChannels = channels
|
|
return nil
|
|
}
|
|
|
|
func (dt *deleteTask) getChannels() []pChan {
|
|
return dt.pChannels
|
|
}
|
|
|
|
func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
|
|
defer sp.End()
|
|
// log := log.Ctx(ctx)
|
|
|
|
if len(dt.req.GetExpr()) == 0 {
|
|
return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression")
|
|
}
|
|
|
|
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
|
|
stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result, numRows, err := repackDeleteMsgByHash(
|
|
ctx,
|
|
dt.primaryKeys,
|
|
dt.vChannels,
|
|
dt.idAllocator,
|
|
dt.ts,
|
|
dt.collectionID,
|
|
dt.req.GetCollectionName(),
|
|
dt.partitionID,
|
|
dt.req.GetPartitionName(),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// send delete request to log broker
|
|
msgPack := &msgstream.MsgPack{
|
|
BeginTs: dt.BeginTs(),
|
|
EndTs: dt.EndTs(),
|
|
}
|
|
|
|
for _, msgs := range result {
|
|
for _, msg := range msgs {
|
|
msgPack.Msgs = append(msgPack.Msgs, msg)
|
|
}
|
|
}
|
|
|
|
log.Debug("send delete request to virtual channels",
|
|
zap.String("collectionName", dt.req.GetCollectionName()),
|
|
zap.Int64("collectionID", dt.collectionID),
|
|
zap.Strings("virtual_channels", dt.vChannels),
|
|
zap.Int64("taskID", dt.ID()),
|
|
zap.Duration("prepare duration", dt.tr.RecordSpan()))
|
|
|
|
err = stream.Produce(msgPack)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dt.sessionTS = dt.ts
|
|
dt.count += numRows
|
|
return nil
|
|
}
|
|
|
|
func (dt *deleteTask) PostExecute(ctx context.Context) error {
|
|
metrics.ProxyDeleteVectors.WithLabelValues(paramtable.GetStringNodeID(), dt.req.GetDbName()).Add(float64(dt.count))
|
|
return nil
|
|
}
|
|
|
|
func repackDeleteMsgByHash(
|
|
ctx context.Context,
|
|
primaryKeys *schemapb.IDs,
|
|
vChannels []string,
|
|
idAllocator allocator.Interface,
|
|
ts uint64,
|
|
collectionID int64,
|
|
collectionName string,
|
|
partitionID int64,
|
|
partitionName string,
|
|
) (map[uint32][]*msgstream.DeleteMsg, int64, error) {
|
|
maxSize := Params.PulsarCfg.MaxMessageSize.GetAsInt()
|
|
hashValues := typeutil.HashPK2Channels(primaryKeys, vChannels)
|
|
// repack delete msg by dmChannel
|
|
result := make(map[uint32][]*msgstream.DeleteMsg)
|
|
lastMessageSize := map[uint32]int{}
|
|
|
|
numRows := int64(0)
|
|
numMessage := 0
|
|
|
|
createMessage := func(key uint32, vchannel string) *msgstream.DeleteMsg {
|
|
numMessage++
|
|
lastMessageSize[key] = 0
|
|
return &msgstream.DeleteMsg{
|
|
BaseMsg: msgstream.BaseMsg{
|
|
Ctx: ctx,
|
|
},
|
|
DeleteRequest: &msgpb.DeleteRequest{
|
|
Base: commonpbutil.NewMsgBase(
|
|
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
|
// msgid of delete msg must be set later
|
|
// or it will be seen as duplicated msg in mq
|
|
commonpbutil.WithTimeStamp(ts),
|
|
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
|
),
|
|
CollectionID: collectionID,
|
|
PartitionID: partitionID,
|
|
CollectionName: collectionName,
|
|
PartitionName: partitionName,
|
|
PrimaryKeys: &schemapb.IDs{},
|
|
ShardName: vchannel,
|
|
},
|
|
}
|
|
}
|
|
|
|
for index, key := range hashValues {
|
|
vchannel := vChannels[key]
|
|
msgs, ok := result[key]
|
|
if !ok {
|
|
result[key] = make([]*msgstream.DeleteMsg, 1)
|
|
msgs = result[key]
|
|
result[key][0] = createMessage(key, vchannel)
|
|
}
|
|
curMsg := msgs[len(msgs)-1]
|
|
size, id := typeutil.GetId(primaryKeys, index)
|
|
if lastMessageSize[key]+16+size > maxSize {
|
|
curMsg = createMessage(key, vchannel)
|
|
result[key] = append(result[key], curMsg)
|
|
}
|
|
curMsg.HashValues = append(curMsg.HashValues, hashValues[index])
|
|
curMsg.Timestamps = append(curMsg.Timestamps, ts)
|
|
|
|
typeutil.AppendID(curMsg.PrimaryKeys, id)
|
|
lastMessageSize[key] += 16 + size
|
|
curMsg.NumRows++
|
|
numRows++
|
|
}
|
|
|
|
// alloc messageID
|
|
start, _, err := idAllocator.Alloc(uint32(numMessage))
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
cnt := int64(0)
|
|
for _, msgs := range result {
|
|
for _, msg := range msgs {
|
|
msg.Base.MsgID = start + cnt
|
|
cnt++
|
|
}
|
|
}
|
|
return result, numRows, nil
|
|
}
|
|
|
|
type deleteRunner struct {
|
|
req *milvuspb.DeleteRequest
|
|
result *milvuspb.MutationResult
|
|
|
|
// channel
|
|
chMgr channelsMgr
|
|
chTicker channelsTimeTicker
|
|
vChannels []vChan
|
|
|
|
idAllocator allocator.Interface
|
|
tsoAllocatorIns tsoAllocator
|
|
limiter types.Limiter
|
|
|
|
// delete info
|
|
schema *schemaInfo
|
|
dbID UniqueID
|
|
collectionID UniqueID
|
|
partitionID UniqueID
|
|
partitionKeyMode bool
|
|
|
|
// for query
|
|
msgID int64
|
|
ts uint64
|
|
lb LBPolicy
|
|
count atomic.Int64
|
|
|
|
// task queue
|
|
queue *dmTaskQueue
|
|
|
|
allQueryCnt atomic.Int64
|
|
sessionTS atomic.Uint64
|
|
}
|
|
|
|
func (dr *deleteRunner) Init(ctx context.Context) error {
|
|
log := log.Ctx(ctx)
|
|
var err error
|
|
|
|
collName := dr.req.GetCollectionName()
|
|
if err := validateCollectionName(collName); err != nil {
|
|
return ErrWithLog(log, "Invalid collection name", err)
|
|
}
|
|
|
|
db, err := globalMetaCache.GetDatabaseInfo(ctx, dr.req.GetDbName())
|
|
if err != nil {
|
|
return merr.WrapErrAsInputErrorWhen(err, merr.ErrDatabaseNotFound)
|
|
}
|
|
dr.dbID = db.dbID
|
|
|
|
dr.collectionID, err = globalMetaCache.GetCollectionID(ctx, dr.req.GetDbName(), collName)
|
|
if err != nil {
|
|
return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound))
|
|
}
|
|
|
|
dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName)
|
|
if err != nil {
|
|
return ErrWithLog(log, "Failed to get collection schema", err)
|
|
}
|
|
|
|
dr.partitionKeyMode = dr.schema.IsPartitionKeyCollection()
|
|
// get partitionIDs of delete
|
|
dr.partitionID = common.AllPartitionsID
|
|
if len(dr.req.PartitionName) > 0 {
|
|
if dr.partitionKeyMode {
|
|
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
|
}
|
|
|
|
partName := dr.req.GetPartitionName()
|
|
if err := validatePartitionTag(partName, true); err != nil {
|
|
return ErrWithLog(log, "Invalid partition name", err)
|
|
}
|
|
partID, err := globalMetaCache.GetPartitionID(ctx, dr.req.GetDbName(), collName, partName)
|
|
if err != nil {
|
|
return ErrWithLog(log, "Failed to get partition id", err)
|
|
}
|
|
dr.partitionID = partID
|
|
}
|
|
|
|
// hash primary keys to channels
|
|
channelNames, err := dr.chMgr.getVChannels(dr.collectionID)
|
|
if err != nil {
|
|
return ErrWithLog(log, "Failed to get primary keys from expr", err)
|
|
}
|
|
dr.vChannels = channelNames
|
|
|
|
dr.result = &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
IDs: &schemapb.IDs{
|
|
IdField: nil,
|
|
},
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dr *deleteRunner) Run(ctx context.Context) error {
|
|
plan, err := planparserv2.CreateRetrievePlan(dr.schema.schemaHelper, dr.req.GetExpr(), dr.req.GetExprTemplateValues())
|
|
if err != nil {
|
|
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("failed to create delete plan: %v", err))
|
|
}
|
|
|
|
if planparserv2.IsAlwaysTruePlan(plan) {
|
|
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("delete plan can't be empty or always true : %s", dr.req.GetExpr()))
|
|
}
|
|
|
|
isSimple, pk, numRow := getPrimaryKeysFromPlan(dr.schema.CollectionSchema, plan)
|
|
if isSimple {
|
|
// if could get delete.primaryKeys from delete expr
|
|
err := dr.simpleDelete(ctx, pk, numRow)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
// if get complex delete expr
|
|
// need query from querynode before delete
|
|
err = dr.complexDelete(ctx, plan)
|
|
if err != nil {
|
|
log.Warn("complex delete failed,but delete some data", zap.Int64("count", dr.result.DeleteCnt), zap.String("expr", dr.req.GetExpr()))
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs) (*deleteTask, error) {
|
|
dt := &deleteTask{
|
|
ctx: ctx,
|
|
Condition: NewTaskCondition(ctx),
|
|
req: dr.req,
|
|
idAllocator: dr.idAllocator,
|
|
chMgr: dr.chMgr,
|
|
chTicker: dr.chTicker,
|
|
collectionID: dr.collectionID,
|
|
partitionID: dr.partitionID,
|
|
partitionKeyMode: dr.partitionKeyMode,
|
|
vChannels: dr.vChannels,
|
|
primaryKeys: primaryKeys,
|
|
}
|
|
var enqueuedTask task = dt
|
|
if streamingutil.IsStreamingServiceEnabled() {
|
|
enqueuedTask = &deleteTaskByStreamingService{deleteTask: dt}
|
|
}
|
|
|
|
if err := dr.queue.Enqueue(enqueuedTask); err != nil {
|
|
log.Error("Failed to enqueue delete task: " + err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
return dt, nil
|
|
}
|
|
|
|
// getStreamingQueryAndDelteFunc return query function used by LBPolicy
|
|
// make sure it concurrent safe
|
|
func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) executeFunc {
|
|
return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
|
|
var partitionIDs []int64
|
|
|
|
// optimize query when partitionKey on
|
|
if dr.partitionKeyMode {
|
|
expr, err := exprutil.ParseExprFromPlan(plan)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
|
|
hashedPartitionNames, err := assignPartitionKeys(ctx, dr.req.GetDbName(), dr.req.GetCollectionName(), partitionKeys)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partitionIDs, err = getPartitionIDs(ctx, dr.req.GetDbName(), dr.req.GetCollectionName(), hashedPartitionNames)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if dr.partitionID != common.InvalidFieldID {
|
|
partitionIDs = []int64{dr.partitionID}
|
|
}
|
|
|
|
log := log.Ctx(ctx).With(
|
|
zap.Int64("collectionID", dr.collectionID),
|
|
zap.Int64s("partitionIDs", partitionIDs),
|
|
zap.String("channel", channel),
|
|
zap.Int64("nodeID", nodeID))
|
|
|
|
// set plan
|
|
_, outputFieldIDs := translatePkOutputFields(dr.schema.CollectionSchema)
|
|
outputFieldIDs = append(outputFieldIDs, common.TimeStampField)
|
|
plan.OutputFieldIds = outputFieldIDs
|
|
|
|
serializedPlan, err := proto.Marshal(plan)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
queryReq := &querypb.QueryRequest{
|
|
Req: &internalpb.RetrieveRequest{
|
|
Base: commonpbutil.NewMsgBase(
|
|
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
|
commonpbutil.WithMsgID(dr.msgID),
|
|
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
|
commonpbutil.WithTargetID(nodeID),
|
|
),
|
|
MvccTimestamp: dr.ts,
|
|
ReqID: paramtable.GetNodeID(),
|
|
DbID: 0, // TODO
|
|
CollectionID: dr.collectionID,
|
|
PartitionIDs: partitionIDs,
|
|
SerializedExprPlan: serializedPlan,
|
|
OutputFieldsId: outputFieldIDs,
|
|
GuaranteeTimestamp: parseGuaranteeTsFromConsistency(dr.ts, dr.ts, dr.req.GetConsistencyLevel()),
|
|
},
|
|
DmlChannels: []string{channel},
|
|
Scope: querypb.DataScope_All,
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
log.Debug("start query for delete", zap.Int64("msgID", dr.msgID))
|
|
client, err := qn.QueryStream(ctx, queryReq)
|
|
if err != nil {
|
|
log.Warn("query stream for delete create failed", zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
taskCh := make(chan *deleteTask, 256)
|
|
var receiveErr error
|
|
go func() {
|
|
receiveErr = dr.receiveQueryResult(ctx, client, taskCh, partitionIDs)
|
|
close(taskCh)
|
|
}()
|
|
var allQueryCnt int64
|
|
// wait all task finish
|
|
var sessionTS uint64
|
|
for task := range taskCh {
|
|
err := task.WaitToFinish()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dr.count.Add(task.count)
|
|
allQueryCnt += task.allQueryCnt
|
|
if sessionTS < task.sessionTS {
|
|
sessionTS = task.sessionTS
|
|
}
|
|
}
|
|
|
|
// query or produce task failed
|
|
if receiveErr != nil {
|
|
return receiveErr
|
|
}
|
|
dr.allQueryCnt.Add(allQueryCnt)
|
|
dr.sessionTS.Store(sessionTS)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (dr *deleteRunner) receiveQueryResult(ctx context.Context, client querypb.QueryNode_QueryStreamClient, taskCh chan *deleteTask, partitionIDs []int64) error {
|
|
for {
|
|
result, err := client.Recv()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
log.Debug("query stream for delete finished", zap.Int64("msgID", dr.msgID))
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
err = merr.Error(result.GetStatus())
|
|
if err != nil {
|
|
log.Warn("query stream for delete get error status", zap.Int64("msgID", dr.msgID), zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
if dr.limiter != nil {
|
|
err := dr.limiter.Alloc(ctx, dr.dbID, map[int64][]int64{dr.collectionID: partitionIDs}, internalpb.RateType_DMLDelete, proto.Size(result.GetIds()))
|
|
if err != nil {
|
|
log.Warn("query stream for delete failed because rate limiter", zap.Int64("msgID", dr.msgID), zap.Error(err))
|
|
return err
|
|
}
|
|
}
|
|
|
|
task, err := dr.produce(ctx, result.GetIds())
|
|
if err != nil {
|
|
log.Warn("produce delete task failed", zap.Error(err))
|
|
return err
|
|
}
|
|
task.allQueryCnt = result.GetAllRetrieveCount()
|
|
|
|
taskCh <- task
|
|
}
|
|
}
|
|
|
|
func (dr *deleteRunner) complexDelete(ctx context.Context, plan *planpb.PlanNode) error {
|
|
rc := timerecord.NewTimeRecorder("QueryStreamDelete")
|
|
var err error
|
|
|
|
dr.msgID, err = dr.idAllocator.AllocOne()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dr.ts, err = dr.tsoAllocatorIns.AllocOne(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = dr.lb.Execute(ctx, CollectionWorkLoad{
|
|
db: dr.req.GetDbName(),
|
|
collectionName: dr.req.GetCollectionName(),
|
|
collectionID: dr.collectionID,
|
|
nq: 1,
|
|
exec: dr.getStreamingQueryAndDelteFunc(plan),
|
|
})
|
|
dr.result.DeleteCnt = dr.count.Load()
|
|
dr.result.Timestamp = dr.sessionTS.Load()
|
|
if err != nil {
|
|
log.Warn("fail to execute complex delete",
|
|
zap.Int64("deleteCnt", dr.result.GetDeleteCnt()),
|
|
zap.Duration("interval", rc.ElapseSpan()),
|
|
zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
log.Info("complex delete finished", zap.Int64("deleteCnt", dr.result.GetDeleteCnt()), zap.Duration("interval", rc.ElapseSpan()))
|
|
return nil
|
|
}
|
|
|
|
func (dr *deleteRunner) simpleDelete(ctx context.Context, pk *schemapb.IDs, numRow int64) error {
|
|
log.Debug("get primary keys from expr",
|
|
zap.Int64("len of primary keys", numRow),
|
|
zap.Int64("collectionID", dr.collectionID),
|
|
zap.Int64("partitionID", dr.partitionID))
|
|
|
|
task, err := dr.produce(ctx, pk)
|
|
if err != nil {
|
|
log.Warn("produce delete task failed")
|
|
return err
|
|
}
|
|
|
|
err = task.WaitToFinish()
|
|
if err == nil {
|
|
dr.result.DeleteCnt = task.count
|
|
dr.result.Timestamp = task.sessionTS
|
|
}
|
|
return err
|
|
}
|
|
|
|
func getPrimaryKeysFromPlan(schema *schemapb.CollectionSchema, plan *planpb.PlanNode) (bool, *schemapb.IDs, int64) {
|
|
// simple delete request need expr with "pk in [a, b]"
|
|
termExpr, ok := plan.Node.(*planpb.PlanNode_Query).Query.Predicates.Expr.(*planpb.Expr_TermExpr)
|
|
if ok {
|
|
if !termExpr.TermExpr.GetColumnInfo().GetIsPrimaryKey() {
|
|
return false, nil, 0
|
|
}
|
|
|
|
ids, rowNum, err := getPrimaryKeysFromTermExpr(schema, termExpr)
|
|
if err != nil {
|
|
return false, nil, 0
|
|
}
|
|
return true, ids, rowNum
|
|
}
|
|
|
|
// simple delete if expr with "pk == a"
|
|
unaryRangeExpr, ok := plan.Node.(*planpb.PlanNode_Query).Query.Predicates.Expr.(*planpb.Expr_UnaryRangeExpr)
|
|
if ok {
|
|
if unaryRangeExpr.UnaryRangeExpr.GetOp() != planpb.OpType_Equal || !unaryRangeExpr.UnaryRangeExpr.GetColumnInfo().GetIsPrimaryKey() {
|
|
return false, nil, 0
|
|
}
|
|
|
|
ids, err := getPrimaryKeysFromUnaryRangeExpr(schema, unaryRangeExpr)
|
|
if err != nil {
|
|
return false, nil, 0
|
|
}
|
|
return true, ids, 1
|
|
}
|
|
|
|
return false, nil, 0
|
|
}
|
|
|
|
func getPrimaryKeysFromUnaryRangeExpr(schema *schemapb.CollectionSchema, unaryRangeExpr *planpb.Expr_UnaryRangeExpr) (res *schemapb.IDs, err error) {
|
|
res = &schemapb.IDs{}
|
|
switch unaryRangeExpr.UnaryRangeExpr.GetColumnInfo().GetDataType() {
|
|
case schemapb.DataType_Int64:
|
|
res.IdField = &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{unaryRangeExpr.UnaryRangeExpr.GetValue().GetInt64Val()},
|
|
},
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
res.IdField = &schemapb.IDs_StrId{
|
|
StrId: &schemapb.StringArray{
|
|
Data: []string{unaryRangeExpr.UnaryRangeExpr.GetValue().GetStringVal()},
|
|
},
|
|
}
|
|
default:
|
|
return res, fmt.Errorf("invalid field data type specifyed in simple delete expr")
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func getPrimaryKeysFromTermExpr(schema *schemapb.CollectionSchema, termExpr *planpb.Expr_TermExpr) (res *schemapb.IDs, rowNum int64, err error) {
|
|
res = &schemapb.IDs{}
|
|
rowNum = int64(len(termExpr.TermExpr.Values))
|
|
switch termExpr.TermExpr.ColumnInfo.GetDataType() {
|
|
case schemapb.DataType_Int64:
|
|
ids := make([]int64, 0)
|
|
for _, v := range termExpr.TermExpr.Values {
|
|
ids = append(ids, v.GetInt64Val())
|
|
}
|
|
res.IdField = &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: ids,
|
|
},
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
ids := make([]string, 0)
|
|
for _, v := range termExpr.TermExpr.Values {
|
|
ids = append(ids, v.GetStringVal())
|
|
}
|
|
res.IdField = &schemapb.IDs_StrId{
|
|
StrId: &schemapb.StringArray{
|
|
Data: ids,
|
|
},
|
|
}
|
|
default:
|
|
return res, 0, fmt.Errorf("invalid field data type specifyed in simple delete expr")
|
|
}
|
|
|
|
return res, rowNum, nil
|
|
}
|