// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
	"context"
	"fmt"
	"strconv"

	"github.com/cockroachdb/errors"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/metrics"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/timerecord"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// upsertTask is the proxy-side task for an upsert request. It is executed as
// an insert plus a delete over the same primary keys within one MsgPack.
type upsertTask struct {
	baseTask
	Condition

	upsertMsg *msgstream.UpsertMsg
	req       *milvuspb.UpsertRequest
	baseMsg   msgstream.BaseMsg

	ctx context.Context

	timestamps       []uint64
	rowIDs           []int64
	result           *milvuspb.MutationResult
	idAllocator      *allocator.IDAllocator
	segIDAssigner    *segIDAssigner
	collectionID     UniqueID
	chMgr            channelsMgr
	chTicker         channelsTimeTicker
	vChannels        []vChan
	pChannels        []pChan
	schema           *schemaInfo
	partitionKeyMode bool
	partitionKeys    *schemapb.FieldData
}

// TraceCtx returns the upsertTask context.
func (it *upsertTask) TraceCtx() context.Context {
	return it.ctx
}

func (it *upsertTask) ID() UniqueID {
	return it.req.Base.MsgID
}

func (it *upsertTask) SetID(uid UniqueID) {
	it.req.Base.MsgID = uid
}

func (it *upsertTask) Name() string {
	return UpsertTaskName
}

func (it *upsertTask) Type() commonpb.MsgType {
	return it.req.Base.MsgType
}

func (it *upsertTask) BeginTs() Timestamp {
	return it.baseMsg.BeginTimestamp
}

func (it *upsertTask) SetTs(ts Timestamp) {
	it.baseMsg.BeginTimestamp = ts
	it.baseMsg.EndTimestamp = ts
}

func (it *upsertTask) EndTs() Timestamp {
	return it.baseMsg.EndTimestamp
}

func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
	ret := make(map[pChan]pChanStatistics)
	channels := it.getChannels()

	beginTs := it.BeginTs()
	endTs := it.EndTs()

	for _, channel := range channels {
		ret[channel] = pChanStatistics{
			minTs: beginTs,
			maxTs: endTs,
		}
	}
	return ret, nil
}

func (it *upsertTask) setChannels() error {
	collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.GetDbName(), it.req.CollectionName)
	if err != nil {
		return err
	}
	channels, err := it.chMgr.getChannels(collID)
	if err != nil {
		return err
	}
	it.pChannels = channels
	return nil
}

func (it *upsertTask) getChannels() []pChan {
	return it.pChannels
}

func (it *upsertTask) OnEnqueue() error {
	if it.req.Base == nil {
		it.req.Base = commonpbutil.NewMsgBase()
	}
	it.req.Base.MsgType = commonpb.MsgType_Upsert
	it.req.Base.SourceID = paramtable.GetNodeID()
	return nil
}
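// insertPreExecute prepares the insert half of the upsert: it allocates a row
// ID and a timestamp for every row, validates primary-key, dynamic-field, and
// partition data against the collection schema, fills field IDs into the
// field data, and records the succeeded indexes on the result.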
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
	collectionName := it.upsertMsg.InsertMsg.CollectionName
	if err := validateCollectionName(collectionName); err != nil {
		log.Error("invalid collection name", zap.String("collectionName", collectionName), zap.Error(err))
		return err
	}

	rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
	// set upsertTask.insertRequest.rowIDs
	tr := timerecord.NewTimeRecorder("applyPK")
	rowIDBegin, rowIDEnd, err := it.idAllocator.Alloc(rowNums)
	if err != nil {
		log.Error("allocate rowIDs failed", zap.String("collectionName", collectionName), zap.Error(err))
		return err
	}
	metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))

	it.upsertMsg.InsertMsg.RowIDs = make([]UniqueID, rowNums)
	it.rowIDs = make([]UniqueID, rowNums)
	for i := rowIDBegin; i < rowIDEnd; i++ {
		offset := i - rowIDBegin
		it.upsertMsg.InsertMsg.RowIDs[offset] = i
		it.rowIDs[offset] = i
	}
	// set upsertTask.insertRequest.timeStamps
	rowNum := it.upsertMsg.InsertMsg.NRows()
	it.upsertMsg.InsertMsg.Timestamps = make([]uint64, rowNum)
	it.timestamps = make([]uint64, rowNum)
	for index := range it.timestamps {
		it.upsertMsg.InsertMsg.Timestamps[index] = it.BeginTs()
		it.timestamps[index] = it.BeginTs()
	}
	// set result.SuccIndex
	sliceIndex := make([]uint32, rowNums)
	for i := uint32(0); i < rowNums; i++ {
		sliceIndex[i] = i
	}
	it.result.SuccIndex = sliceIndex

	if it.schema.EnableDynamicField {
		err := checkDynamicFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
		if err != nil {
			return err
		}
	}

	// when autoID == false, use the passed pk as the new pk;
	// when autoID == true, automatically generate a new pk
	it.result.IDs, err = checkPrimaryFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg, false)
	log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
	if err != nil {
		log.Warn("check primary field data and hash primary key failed when upsert", zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
	}

	// set field ID to insert field data
	err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema)
	if err != nil {
		log.Warn("insert set fieldID to fieldData failed when upsert", zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
	}

	if it.partitionKeyMode {
		fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema.CollectionSchema)
		it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg)
		if err != nil {
			log.Warn("get partition keys from insert request failed", zap.String("collectionName", collectionName), zap.Error(err))
			return err
		}
	} else {
		partitionTag := it.upsertMsg.InsertMsg.PartitionName
		if err = validatePartitionTag(partitionTag, true); err != nil {
			log.Warn("invalid partition name", zap.String("partitionName", partitionTag), zap.Error(err))
			return err
		}
	}

	if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
		Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema, it.upsertMsg.InsertMsg.NRows()); err != nil {
		return err
	}

	log.Debug("Proxy Upsert insertPreExecute done")
	return nil
}
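// deletePreExecute prepares the delete half of the upsert: it resolves the
// collection ID, picks the target partition (or all partitions when partition
// key mode is enabled), and stamps every delete row with the task's begin
// timestamp.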
func (it *upsertTask) deletePreExecute(ctx context.Context) error {
	collName := it.upsertMsg.DeleteMsg.CollectionName
	log := log.Ctx(ctx).With(
		zap.String("collectionName", collName))

	if err := validateCollectionName(collName); err != nil {
		log.Info("Invalid collection name", zap.Error(err))
		return err
	}
	collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collName)
	if err != nil {
		log.Info("Failed to get collection id", zap.Error(err))
		return err
	}
	it.upsertMsg.DeleteMsg.CollectionID = collID
	it.collectionID = collID

	if it.partitionKeyMode {
		// entities with the same pk but different partition keys may be hashed
		// to different physical partitions, so target all partitions: every
		// segment holding this pk under the collection gets the delete record
		it.upsertMsg.DeleteMsg.PartitionID = common.AllPartitionsID
	} else {
		// the partition name is either defaultPartitionName or one specified by the SDK
		partName := it.upsertMsg.DeleteMsg.PartitionName
		if err := validatePartitionTag(partName, true); err != nil {
			log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
			return err
		}
		partID, err := globalMetaCache.GetPartitionID(ctx, it.req.GetDbName(), collName, partName)
		if err != nil {
			log.Warn("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
			return err
		}
		it.upsertMsg.DeleteMsg.PartitionID = partID
	}

	it.upsertMsg.DeleteMsg.Timestamps = make([]uint64, it.upsertMsg.DeleteMsg.NumRows)
	for index := range it.upsertMsg.DeleteMsg.Timestamps {
		it.upsertMsg.DeleteMsg.Timestamps[index] = it.BeginTs()
	}
	log.Debug("Proxy Upsert deletePreExecute done")
	return nil
}

// PreExecute validates the request, prepares the mutation result, and splits
// the upsert into an InsertMsg and a DeleteMsg before running the insert- and
// delete-specific pre-execution steps.
func (it *upsertTask) PreExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-PreExecute")
	defer sp.End()

	collectionName := it.req.CollectionName
	log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))

	it.result = &milvuspb.MutationResult{
		Status: merr.Success(),
		IDs: &schemapb.IDs{
			IdField: nil,
		},
		Timestamp: it.EndTs(),
	}

	schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("Failed to get collection schema", zap.String("collectionName", collectionName), zap.Error(err))
		return err
	}
	it.schema = schema

	it.partitionKeyMode, err = isPartitionKeyMode(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("check partition key mode failed", zap.String("collectionName", collectionName), zap.Error(err))
		return err
	}
	if it.partitionKeyMode {
		if len(it.req.GetPartitionName()) > 0 {
			return errors.New("manually specifying the partition name is not supported when partition key mode is used")
		}
	} else {
		// when partition key mode is not used, fall back to the default
		// partition (_default) if no partition name was specified
		partitionTag := it.req.GetPartitionName()
		if len(partitionTag) <= 0 {
			partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
			it.req.PartitionName = partitionTag
		}
	}

	it.upsertMsg = &msgstream.UpsertMsg{
		InsertMsg: &msgstream.InsertMsg{
			InsertRequest: msgpb.InsertRequest{
				Base: commonpbutil.NewMsgBase(
					commonpbutil.WithMsgType(commonpb.MsgType_Insert),
					commonpbutil.WithSourceID(paramtable.GetNodeID()),
				),
				CollectionName: it.req.CollectionName,
				PartitionName:  it.req.PartitionName,
				FieldsData:     it.req.FieldsData,
				NumRows:        uint64(it.req.NumRows),
				Version:        msgpb.InsertDataVersion_ColumnBased,
				DbName:         it.req.DbName,
			},
		},
		DeleteMsg: &msgstream.DeleteMsg{
			DeleteRequest: msgpb.DeleteRequest{
				Base: commonpbutil.NewMsgBase(
					commonpbutil.WithMsgType(commonpb.MsgType_Delete),
					commonpbutil.WithSourceID(paramtable.GetNodeID()),
				),
				DbName:         it.req.DbName,
				CollectionName: it.req.CollectionName,
				NumRows:        int64(it.req.NumRows),
				PartitionName:  it.req.PartitionName,
				CollectionID:   it.collectionID,
			},
		},
	}
	err = it.insertPreExecute(ctx)
	if err != nil {
		log.Warn("Failed to insertPreExecute", zap.Error(err))
		return err
	}

	err = it.deletePreExecute(ctx)
	if err != nil {
		log.Warn("Failed to deletePreExecute", zap.Error(err))
		return err
	}

	it.result.DeleteCnt = it.upsertMsg.DeleteMsg.NumRows
	it.result.InsertCnt = int64(it.upsertMsg.InsertMsg.NumRows)
	if it.result.DeleteCnt != it.result.InsertCnt {
		log.Info("DeleteCnt and InsertCnt are not the same when upsert",
			zap.Int64("DeleteCnt", it.result.DeleteCnt),
			zap.Int64("InsertCnt", it.result.InsertCnt))
	}
	it.result.UpsertCnt = it.result.InsertCnt
	log.Debug("Proxy Upsert PreExecute done")
	return nil
}
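// insertExecute assigns segment IDs to the insert data, repacks the rows into
// per-channel insert messages, and appends them to msgPack for a later
// combined produce in Execute.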
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
	defer tr.Elapse("insert execute done when insertExecute")

	collectionName := it.upsertMsg.InsertMsg.CollectionName
	collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		return err
	}
	it.upsertMsg.InsertMsg.CollectionID = collID
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", collID))

	getCacheDur := tr.RecordSpan()
	_, err = it.chMgr.getOrCreateDmlStream(collID)
	if err != nil {
		return err
	}
	getMsgStreamDur := tr.RecordSpan()

	channelNames, err := it.chMgr.getVChannels(collID)
	if err != nil {
		log.Warn("get vChannels failed when insertExecute", zap.Error(err))
		it.result.Status = merr.Status(err)
		return err
	}

	log.Debug("send insert request to virtual channels when insertExecute",
		zap.String("collection", it.req.GetCollectionName()),
		zap.String("partition", it.req.GetPartitionName()),
		zap.Int64("collection_id", collID),
		zap.Strings("virtual_channels", channelNames),
		zap.Int64("task_id", it.ID()),
		zap.Duration("get cache duration", getCacheDur),
		zap.Duration("get msgStream duration", getMsgStreamDur))

	// assign segmentID for insert data and repack data by segmentID
	var insertMsgPack *msgstream.MsgPack
	if it.partitionKeys == nil {
		insertMsgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
	} else {
		insertMsgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
	}
	if err != nil {
		log.Warn("assign segmentID and repack insert data failed when insertExecute", zap.Error(err))
		it.result.Status = merr.Status(err)
		return err
	}
	assignSegmentIDDur := tr.RecordSpan()
	log.Debug("assign segmentID for insert data success when insertExecute",
		zap.String("collectionName", it.req.CollectionName),
		zap.Duration("assign segmentID duration", assignSegmentIDDur))
	msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
	log.Debug("Proxy Insert Execute done when upsert",
		zap.String("collectionName", collectionName))
	return nil
}
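// deleteExecute hashes the primary keys produced in PreExecute to the
// collection's virtual channels, repacks them into one DeleteMsg per channel,
// and appends those messages to msgPack alongside the insert messages.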
log.Debug("Proxy Insert Execute done when upsert", zap.String("collectionName", collectionName)) return nil } func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID())) collID := it.upsertMsg.DeleteMsg.CollectionID log := log.Ctx(ctx).With( zap.Int64("collectionID", collID)) // hash primary keys to channels channelNames, err := it.chMgr.getVChannels(collID) if err != nil { log.Warn("get vChannels failed when deleteExecute", zap.Error(err)) it.result.Status = merr.Status(err) return err } it.upsertMsg.DeleteMsg.PrimaryKeys = it.result.IDs it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames) // repack delete msg by dmChannel result := make(map[uint32]msgstream.TsMsg) collectionName := it.upsertMsg.DeleteMsg.CollectionName collectionID := it.upsertMsg.DeleteMsg.CollectionID partitionID := it.upsertMsg.DeleteMsg.PartitionID partitionName := it.upsertMsg.DeleteMsg.PartitionName proxyID := it.upsertMsg.DeleteMsg.Base.SourceID for index, key := range it.upsertMsg.DeleteMsg.HashValues { ts := it.upsertMsg.DeleteMsg.Timestamps[index] _, ok := result[key] if !ok { msgid, err := it.idAllocator.AllocOne() if err != nil { errors.Wrap(err, "failed to allocate MsgID for delete of upsert") } sliceRequest := msgpb.DeleteRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_Delete), commonpbutil.WithTimeStamp(ts), // id of upsertTask were set as ts in scheduler // msgid of delete msg must be set // or it will be seen as duplicated msg in mq commonpbutil.WithMsgID(msgid), commonpbutil.WithSourceID(proxyID), ), CollectionID: collectionID, PartitionID: partitionID, CollectionName: collectionName, PartitionName: partitionName, PrimaryKeys: &schemapb.IDs{}, } deleteMsg := &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ Ctx: ctx, }, DeleteRequest: sliceRequest, } result[key] = deleteMsg } curMsg := result[key].(*msgstream.DeleteMsg) curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index]) curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index]) typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index) curMsg.NumRows++ curMsg.ShardName = channelNames[key] } // send delete request to log broker deleteMsgPack := &msgstream.MsgPack{ BeginTs: it.upsertMsg.DeleteMsg.BeginTs(), EndTs: it.upsertMsg.DeleteMsg.EndTs(), } for _, msg := range result { if msg != nil { deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg) } } msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...) 
log.Debug("Proxy Upsert deleteExecute done", zap.Int64("collection_id", collID), zap.Strings("virtual_channels", channelNames), zap.Int64("taskID", it.ID()), zap.Duration("prepare duration", tr.ElapseSpan())) return nil } func (it *upsertTask) Execute(ctx context.Context) (err error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute") defer sp.End() log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName)) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID())) stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID) if err != nil { return err } msgPack := &msgstream.MsgPack{ BeginTs: it.BeginTs(), EndTs: it.EndTs(), } err = it.insertExecute(ctx, msgPack) if err != nil { log.Warn("Fail to insertExecute", zap.Error(err)) return err } err = it.deleteExecute(ctx, msgPack) if err != nil { log.Warn("Fail to deleteExecute", zap.Error(err)) return err } tr.RecordSpan() err = stream.Produce(msgPack) if err != nil { it.result.Status = merr.Status(err) return err } sendMsgDur := tr.RecordSpan() metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds())) totalDur := tr.ElapseSpan() log.Debug("Proxy Upsert Execute done", zap.Int64("taskID", it.ID()), zap.Duration("total duration", totalDur)) return nil } func (it *upsertTask) PostExecute(ctx context.Context) error { return nil }