diff --git a/internal/metastore/kv/streamingcoord/constant.go b/internal/metastore/kv/streamingcoord/constant.go index 0603aeda4d..5ae1f85b7d 100644 --- a/internal/metastore/kv/streamingcoord/constant.go +++ b/internal/metastore/kv/streamingcoord/constant.go @@ -2,5 +2,5 @@ package streamingcoord const ( MetaPrefix = "streamingcoord-meta" - PChannelMeta = MetaPrefix + "/pchannel-meta" + PChannelMeta = MetaPrefix + "/pchannel" ) diff --git a/internal/proto/streaming.proto b/internal/proto/streaming.proto index 0a7221f6e8..7b6740e843 100644 --- a/internal/proto/streaming.proto +++ b/internal/proto/streaming.proto @@ -13,7 +13,7 @@ import "google/protobuf/empty.proto"; // MessageID is the unique identifier of a message. message MessageID { - bytes id = 1; + string id = 1; } // Message is the basic unit of communication between publisher and consumer. diff --git a/internal/proto/streamingpb/extends.go b/internal/proto/streamingpb/extends.go index 5d0f3fd85d..6d562592ca 100644 --- a/internal/proto/streamingpb/extends.go +++ b/internal/proto/streamingpb/extends.go @@ -1,5 +1,5 @@ package streamingpb const ( - ServiceMethodPrefix = "/milvus.proto.log" + ServiceMethodPrefix = "/milvus.proto.streaming" ) diff --git a/internal/streamingnode/server/resource/timestamp/timestamp_allocator.go b/internal/streamingnode/server/resource/idalloc/allocator.go similarity index 78% rename from internal/streamingnode/server/resource/timestamp/timestamp_allocator.go rename to internal/streamingnode/server/resource/idalloc/allocator.go index 6d2eba1a6a..6b3dd6e6af 100644 --- a/internal/streamingnode/server/resource/timestamp/timestamp_allocator.go +++ b/internal/streamingnode/server/resource/idalloc/allocator.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package timestamp +package idalloc import ( "context" @@ -28,15 +28,28 @@ const batchAllocateSize = 1000 var _ Allocator = (*allocatorImpl)(nil) -// NewAllocator creates a new allocator. -func NewAllocator(rc types.RootCoordClient) Allocator { +// NewTSOAllocator creates a new allocator. +func NewTSOAllocator(rc types.RootCoordClient) Allocator { return &allocatorImpl{ mu: sync.Mutex{}, - remoteAllocator: newRemoteAllocator(rc), + remoteAllocator: newTSOAllocator(rc), localAllocator: newLocalAllocator(), } } +// NewIDAllocator creates a new allocator. +func NewIDAllocator(rc types.RootCoordClient) Allocator { + return &allocatorImpl{ + mu: sync.Mutex{}, + remoteAllocator: newIDAllocator(rc), + localAllocator: newLocalAllocator(), + } +} + +type remoteBatchAllocator interface { + batchAllocate(ctx context.Context, count uint32) (uint64, int, error) +} + type Allocator interface { // Allocate allocates a timestamp. Allocate(ctx context.Context) (uint64, error) @@ -48,7 +61,7 @@ type Allocator interface { type allocatorImpl struct { mu sync.Mutex - remoteAllocator *remoteAllocator + remoteAllocator remoteBatchAllocator localAllocator *localAllocator } @@ -77,7 +90,7 @@ func (ta *allocatorImpl) Sync() { // allocateRemote allocates timestamp from remote root coordinator. func (ta *allocatorImpl) allocateRemote(ctx context.Context) (uint64, error) { // Update local allocator from remote. - start, count, err := ta.remoteAllocator.allocate(ctx, batchAllocateSize) + start, count, err := ta.remoteAllocator.batchAllocate(ctx, batchAllocateSize) if err != nil { return 0, err } diff --git a/internal/streamingnode/server/resource/timestamp/timestamp_allocator_test.go b/internal/streamingnode/server/resource/idalloc/allocator_test.go similarity index 93% rename from internal/streamingnode/server/resource/timestamp/timestamp_allocator_test.go rename to internal/streamingnode/server/resource/idalloc/allocator_test.go index bb0c41a99f..c4db2e520a 100644 --- a/internal/streamingnode/server/resource/timestamp/timestamp_allocator_test.go +++ b/internal/streamingnode/server/resource/idalloc/allocator_test.go @@ -1,4 +1,4 @@ -package timestamp +package idalloc import ( "context" @@ -19,7 +19,7 @@ func TestTimestampAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) - allocator := NewAllocator(client) + allocator := NewTSOAllocator(client) for i := 0; i < 5000; i++ { ts, err := allocator.Allocate(context.Background()) @@ -46,7 +46,7 @@ func TestTimestampAllocator(t *testing.T) { }, nil }, ) - allocator = NewAllocator(client) + allocator = NewTSOAllocator(client) _, err := allocator.Allocate(context.Background()) assert.Error(t, err) } diff --git a/internal/streamingnode/server/resource/timestamp/basic_allocator.go b/internal/streamingnode/server/resource/idalloc/basic_allocator.go similarity index 58% rename from internal/streamingnode/server/resource/timestamp/basic_allocator.go rename to internal/streamingnode/server/resource/idalloc/basic_allocator.go index 448c8274a4..8e0ad90e63 100644 --- a/internal/streamingnode/server/resource/timestamp/basic_allocator.go +++ b/internal/streamingnode/server/resource/idalloc/basic_allocator.go @@ -1,4 +1,4 @@ -package timestamp +package idalloc import ( "context" @@ -54,22 +54,22 @@ func (a *localAllocator) exhausted() { a.nextStartID = a.endStartID } -// remoteAllocator allocate timestamp from remote root coordinator. -type remoteAllocator struct { +// tsoAllocator allocate timestamp from remote root coordinator. +type tsoAllocator struct { rc types.RootCoordClient nodeID int64 } -// newRemoteAllocator creates a new remote allocator. -func newRemoteAllocator(rc types.RootCoordClient) *remoteAllocator { - a := &remoteAllocator{ +// newTSOAllocator creates a new remote allocator. +func newTSOAllocator(rc types.RootCoordClient) *tsoAllocator { + a := &tsoAllocator{ nodeID: paramtable.GetNodeID(), rc: rc, } return a } -func (ta *remoteAllocator) allocate(ctx context.Context, count uint32) (uint64, int, error) { +func (ta *tsoAllocator) batchAllocate(ctx context.Context, count uint32) (uint64, int, error) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() req := &rootcoordpb.AllocTimestampRequest{ @@ -93,3 +93,46 @@ func (ta *remoteAllocator) allocate(ctx context.Context, count uint32) (uint64, } return resp.GetTimestamp(), int(resp.GetCount()), nil } + +// idAllocator allocate timestamp from remote root coordinator. +type idAllocator struct { + rc types.RootCoordClient + nodeID int64 +} + +// newIDAllocator creates a new remote allocator. +func newIDAllocator(rc types.RootCoordClient) *idAllocator { + a := &idAllocator{ + nodeID: paramtable.GetNodeID(), + rc: rc, + } + return a +} + +func (ta *idAllocator) batchAllocate(ctx context.Context, count uint32) (uint64, int, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + req := &rootcoordpb.AllocIDRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_RequestID), + commonpbutil.WithMsgID(0), + commonpbutil.WithSourceID(ta.nodeID), + ), + Count: count, + } + + resp, err := ta.rc.AllocID(ctx, req) + if err != nil { + return 0, 0, fmt.Errorf("AllocID Failed:%w", err) + } + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return 0, 0, fmt.Errorf("AllocID Failed:%s", resp.GetStatus().GetReason()) + } + if resp == nil { + return 0, 0, fmt.Errorf("empty AllocID") + } + if resp.GetID() < 0 { + panic("get unexpected negative id") + } + return uint64(resp.GetID()), int(resp.GetCount()), nil +} diff --git a/internal/streamingnode/server/resource/timestamp/basic_allocator_test.go b/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go similarity index 58% rename from internal/streamingnode/server/resource/timestamp/basic_allocator_test.go rename to internal/streamingnode/server/resource/idalloc/basic_allocator_test.go index 53b6adc098..081832006f 100644 --- a/internal/streamingnode/server/resource/timestamp/basic_allocator_test.go +++ b/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go @@ -1,4 +1,4 @@ -package timestamp +package idalloc import ( "context" @@ -58,14 +58,14 @@ func TestLocalAllocator(t *testing.T) { assert.Zero(t, ts) } -func TestRemoteAllocator(t *testing.T) { +func TestRemoteTSOAllocator(t *testing.T) { paramtable.Init() paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) - allocator := newRemoteAllocator(client) - ts, count, err := allocator.allocate(context.Background(), 100) + allocator := newTSOAllocator(client) + ts, count, err := allocator.batchAllocate(context.Background(), 100) assert.NoError(t, err) assert.NotZero(t, ts) assert.Equal(t, count, 100) @@ -77,8 +77,8 @@ func TestRemoteAllocator(t *testing.T) { return nil, errors.New("test") }, ) - allocator = newRemoteAllocator(client) - _, _, err = allocator.allocate(context.Background(), 100) + allocator = newTSOAllocator(client) + _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Unset() @@ -91,7 +91,45 @@ func TestRemoteAllocator(t *testing.T) { }, nil }, ) - allocator = newRemoteAllocator(client) - _, _, err = allocator.allocate(context.Background(), 100) + allocator = newTSOAllocator(client) + _, _, err = allocator.batchAllocate(context.Background(), 100) + assert.Error(t, err) +} + +func TestRemoteIDAllocator(t *testing.T) { + paramtable.Init() + paramtable.SetNodeID(1) + + client := NewMockRootCoordClient(t) + + allocator := newIDAllocator(client) + ts, count, err := allocator.batchAllocate(context.Background(), 100) + assert.NoError(t, err) + assert.NotZero(t, ts) + assert.Equal(t, count, 100) + + // Test error. + client = mocks.NewMockRootCoordClient(t) + client.EXPECT().AllocID(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, atr *rootcoordpb.AllocIDRequest, co ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + return nil, errors.New("test") + }, + ) + allocator = newIDAllocator(client) + _, _, err = allocator.batchAllocate(context.Background(), 100) + assert.Error(t, err) + + client.EXPECT().AllocID(mock.Anything, mock.Anything).Unset() + client.EXPECT().AllocID(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, atr *rootcoordpb.AllocIDRequest, co ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + return &rootcoordpb.AllocIDResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_ForceDeny, + }, + }, nil + }, + ) + allocator = newIDAllocator(client) + _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) } diff --git a/internal/streamingnode/server/resource/timestamp/test_mock_root_coord_client.go b/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go similarity index 64% rename from internal/streamingnode/server/resource/timestamp/test_mock_root_coord_client.go rename to internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go index dc28876366..aac552af2c 100644 --- a/internal/streamingnode/server/resource/timestamp/test_mock_root_coord_client.go +++ b/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go @@ -1,7 +1,7 @@ //go:build test // +build test -package timestamp +package idalloc import ( "context" @@ -34,6 +34,21 @@ func NewMockRootCoordClient(t *testing.T) *mocks.MockRootCoordClient { Count: atr.Count, }, nil }, - ) + ).Maybe() + client.EXPECT().AllocID(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, atr *rootcoordpb.AllocIDRequest, co ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + if atr.Count > 1000 { + panic(fmt.Sprintf("count %d is too large", atr.Count)) + } + c := counter.Add(uint64(atr.Count)) + return &rootcoordpb.AllocIDResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + ID: int64(c - uint64(atr.Count)), + Count: atr.Count, + }, nil + }, + ).Maybe() return client } diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index a16ac4681b..0b964312d4 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -5,7 +5,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/types" ) @@ -35,9 +35,10 @@ func Init(opts ...optResourceInit) { for _, opt := range opts { opt(r) } - r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient) + r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) - assertNotNil(r.TimestampAllocator()) + assertNotNil(r.TSOAllocator()) assertNotNil(r.ETCD()) assertNotNil(r.RootCoordClient()) } @@ -50,16 +51,22 @@ func Resource() *resourceImpl { // resourceImpl is a basic resource dependency for streamingnode server. // All utility on it is concurrent-safe and singleton. type resourceImpl struct { - timestampAllocator timestamp.Allocator + timestampAllocator idalloc.Allocator + idAllocator idalloc.Allocator etcdClient *clientv3.Client rootCoordClient types.RootCoordClient } -// TimestampAllocator returns the timestamp allocator to allocate timestamp. -func (r *resourceImpl) TimestampAllocator() timestamp.Allocator { +// TSOAllocator returns the timestamp allocator to allocate timestamp. +func (r *resourceImpl) TSOAllocator() idalloc.Allocator { return r.timestampAllocator } +// IDAllocator returns the id allocator to allocate id. +func (r *resourceImpl) IDAllocator() idalloc.Allocator { + return r.idAllocator +} + // ETCD returns the etcd client. func (r *resourceImpl) ETCD() *clientv3.Client { return r.etcdClient diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index b8c0f3f62b..7c84d920de 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -21,7 +21,7 @@ func TestInit(t *testing.T) { }) Init(OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t))) - assert.NotNil(t, Resource().TimestampAllocator()) + assert.NotNil(t, Resource().TSOAllocator()) assert.NotNil(t, Resource().ETCD()) assert.NotNil(t, Resource().RootCoordClient()) } diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index 5079f685fb..1bb2bd3a8a 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -3,7 +3,7 @@ package resource -import "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp" +import "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" // InitForTest initializes the singleton of resources for test. func InitForTest(opts ...optResourceInit) { @@ -12,6 +12,6 @@ func InitForTest(opts ...optResourceInit) { opt(r) } if r.rootCoordClient != nil { - r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient) + r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) } } diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server.go b/internal/streamingnode/server/service/handler/consumer/consume_server.go index 156d018f4a..b1cf7d1538 100644 --- a/internal/streamingnode/server/service/handler/consumer/consume_server.go +++ b/internal/streamingnode/server/service/handler/consumer/consume_server.go @@ -31,9 +31,7 @@ func CreateConsumeServer(walManager walmanager.Manager, streamServer streamingpb if err != nil { return nil, status.NewInvaildArgument("create consumer request is required") } - - pchanelInfo := typeconverter.NewPChannelInfoFromProto(createReq.Pchannel) - l, err := walManager.GetAvailableWAL(pchanelInfo) + l, err := walManager.GetAvailableWAL(typeconverter.NewPChannelInfoFromProto(createReq.GetPchannel())) if err != nil { return nil, err } diff --git a/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go index b5332a9cbf..22499547de 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go +++ b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go @@ -19,12 +19,10 @@ func (p *produceGrpcServerHelper) SendProduceMessage(resp *streamingpb.ProduceMe } // SendCreated sends the create response to client. -func (p *produceGrpcServerHelper) SendCreated(walName string) error { +func (p *produceGrpcServerHelper) SendCreated(resp *streamingpb.CreateProducerResponse) error { return p.Send(&streamingpb.ProduceResponse{ Response: &streamingpb.ProduceResponse_Create{ - Create: &streamingpb.CreateProducerResponse{ - WalName: walName, - }, + Create: resp, }, }) } diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 13135f343c..0841f15774 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -33,7 +33,7 @@ func CreateProduceServer(walManager walmanager.Manager, streamServer streamingpb if err != nil { return nil, status.NewInvaildArgument("create producer request is required") } - l, err := walManager.GetAvailableWAL(typeconverter.NewPChannelInfoFromProto(createReq.Pchannel)) + l, err := walManager.GetAvailableWAL(typeconverter.NewPChannelInfoFromProto(createReq.GetPchannel())) if err != nil { return nil, err } @@ -41,7 +41,9 @@ func CreateProduceServer(walManager walmanager.Manager, streamServer streamingpb produceServer := &produceGrpcServerHelper{ StreamingNodeHandlerService_ProduceServer: streamServer, } - if err := produceServer.SendCreated(l.WALName()); err != nil { + if err := produceServer.SendCreated(&streamingpb.CreateProducerResponse{ + WalName: l.WALName(), + }); err != nil { return nil, errors.Wrap(err, "at send created") } return &ProduceServer{ @@ -170,13 +172,13 @@ func (p *ProduceServer) handleProduce(req *streamingpb.ProduceMessageRequest) { func (p *ProduceServer) validateMessage(msg message.MutableMessage) error { // validate the msg. if !msg.Version().GT(message.VersionOld) { - return status.NewInner("unsupported message version") + return status.NewInvaildArgument("unsupported message version") } if !msg.MessageType().Valid() { - return status.NewInner("unsupported message type") + return status.NewInvaildArgument("unsupported message type") } if msg.Payload() == nil { - return status.NewInner("empty payload for message") + return status.NewInvaildArgument("empty payload for message") } return nil } diff --git a/internal/streamingnode/server/service/handler/producer/produce_server_test.go b/internal/streamingnode/server/service/handler/producer/produce_server_test.go index f2468bf879..eec1582741 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server_test.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server_test.go @@ -3,6 +3,7 @@ package producer import ( "context" "io" + "strconv" "sync" "testing" "time" @@ -56,6 +57,7 @@ func TestCreateProduceServer(t *testing.T) { l := mock_wal.NewMockWAL(t) l.EXPECT().WALName().Return("test") manager.ExpectedCalls = nil + l.EXPECT().WALName().Return("test") manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: 1}).Return(l, nil) grpcProduceServer.EXPECT().Send(mock.Anything).Return(errors.New("send created failed")) assertCreateProduceServerFail(t, manager, grpcProduceServer) @@ -214,7 +216,7 @@ func TestProduceServerRecvArm(t *testing.T) { Payload: []byte("test"), Properties: map[string]string{ "_v": "1", - "_t": "1", + "_t": strconv.FormatInt(int64(message.MessageTypeTimeTick), 10), }, }, }, diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index d3894214ef..5cc6eb0011 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -100,7 +100,7 @@ func (w *walAdaptorImpl) AppendAsync(ctx context.Context, msg message.MutableMes _ = w.appendExecutionPool.Submit(func() (struct{}, error) { defer w.lifetime.Done() - msgID, err := w.inner.Append(ctx, msg) + msgID, err := w.Append(ctx, msg) cb(msgID, err) return struct{}{}, nil }) diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index a48ae83d63..aad5964f5a 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -16,7 +16,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -32,7 +32,7 @@ type walTestFramework struct { } func TestWAL(t *testing.T) { - rc := timestamp.NewMockRootCoordClient(t) + rc := idalloc.NewMockRootCoordClient(t) resource.InitForTest(resource.OptRootCoordClient(rc)) b := registry.MustGetBuilder(walimplstest.WALName) diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go index 815003fb3c..55f9be181d 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -18,7 +18,7 @@ func TestAck(t *testing.T) { ctx := context.Background() - rc := timestamp.NewMockRootCoordClient(t) + rc := idalloc.NewMockRootCoordClient(t) resource.InitForTest(resource.OptRootCoordClient(rc)) ackManager := NewAckManager() diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go index 4152e4a361..93ac1842a4 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go @@ -31,7 +31,7 @@ func (ta *AckManager) Allocate(ctx context.Context) (*Acker, error) { defer ta.mu.Unlock() // allocate one from underlying allocator first. - ts, err := resource.Resource().TimestampAllocator().Allocate(ctx) + ts, err := resource.Resource().TSOAllocator().Allocate(ctx) if err != nil { return nil, err } @@ -47,7 +47,7 @@ func (ta *AckManager) Allocate(ctx context.Context) (*Acker, error) { // Concurrent safe to call with Allocate. func (ta *AckManager) SyncAndGetAcknowledged(ctx context.Context) ([]*AckDetail, error) { // local timestamp may out of date, sync the underlying allocator before get last all acknowledged. - resource.Resource().TimestampAllocator().Sync() + resource.Resource().TSOAllocator().Sync() // Allocate may be uncalled in long term, and the recorder may be out of date. // Do a Allocate and Ack, can sync up the recorder with internal timetick.TimestampAllocator latest time. diff --git a/internal/streamingnode/server/walmanager/manager.go b/internal/streamingnode/server/walmanager/manager.go index 811ae42f15..445f1c68bd 100644 --- a/internal/streamingnode/server/walmanager/manager.go +++ b/internal/streamingnode/server/walmanager/manager.go @@ -12,16 +12,20 @@ var _ Manager = (*managerImpl)(nil) // Manager is the interface for managing the wal instances. type Manager interface { // Open opens a wal instance for the channel on this Manager. + // Return `IgnoreOperation` error if the channel is not found. + // Return `UnmatchedChannelTerm` error if the channel term is not matched. Open(ctx context.Context, channel types.PChannelInfo) error // GetAvailableWAL returns a available wal instance for the channel. // Return nil if the wal instance is not found. - GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) + GetAvailableWAL(types.PChannelInfo) (wal.WAL, error) // GetAllAvailableWALInfo returns all available channel info. GetAllAvailableChannels() ([]types.PChannelInfo, error) // Remove removes the wal instance for the channel. + // Return `IgnoreOperation` error if the channel is not found. + // Return `UnmatchedChannelTerm` error if the channel term is not matched. Remove(ctx context.Context, channel types.PChannelInfo) error // Close these manager and release all managed WAL. diff --git a/internal/streamingnode/server/walmanager/manager_impl.go b/internal/streamingnode/server/walmanager/manager_impl.go index 70f4ed26b5..3fcb5c1dd1 100644 --- a/internal/streamingnode/server/walmanager/manager_impl.go +++ b/internal/streamingnode/server/walmanager/manager_impl.go @@ -71,6 +71,7 @@ func (m *managerImpl) Remove(ctx context.Context, channel types.PChannelInfo) (e m.lifetime.Done() if err != nil { log.Warn("remove wal failed", zap.Error(err), zap.String("channel", channel.Name), zap.Int64("term", channel.Term)) + return } log.Info("remove wal success", zap.String("channel", channel.Name), zap.Int64("term", channel.Term)) }() diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 84d3cf8830..828ea2d2cb 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -1237,3 +1237,8 @@ func saveServerInfoInternal(role string, serverID int64, pid int) { func SaveServerInfo(role string, serverID int64) { saveServerInfoInternal(role, serverID, os.Getpid()) } + +// GetSessionPrefixByRole get session prefix by role +func GetSessionPrefixByRole(role string) string { + return path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), DefaultServiceRoot, role) +} diff --git a/internal/util/streamingutil/service/balancer/picker/server_id_picker.go b/internal/util/streamingutil/service/balancer/picker/server_id_picker.go index fe714713cc..9e4c9648cb 100644 --- a/internal/util/streamingutil/service/balancer/picker/server_id_picker.go +++ b/internal/util/streamingutil/service/balancer/picker/server_id_picker.go @@ -17,7 +17,7 @@ import ( var _ balancer.Picker = &serverIDPicker{} -var ErrNoSubConnNotExist = status.New(codes.Unavailable, "sub connection not exist").Err() +var ErrSubConnNotExist = status.New(codes.Unavailable, "subConn not exist").Err() type subConnInfo struct { serverID int64 @@ -107,18 +107,18 @@ func (p *serverIDPicker) useGivenAddr(_ balancer.PickInfo, serverID int64) (*sub // FailPrecondition will be converted to Internal by grpc framework in function `IsRestrictedControlPlaneCode`. // Use Unavailable here. // Unavailable code is retried in many cases, so it's better to be used here to avoid when Subconn is not ready scene. - return nil, ErrNoSubConnNotExist + return nil, ErrSubConnNotExist } -// IsErrNoSubConnForPick checks whether the error is ErrNoSubConnForPick. -func IsErrNoSubConnForPick(err error) bool { - if errors.Is(err, ErrNoSubConnNotExist) { +// IsErrSubConnNoExist checks whether the error is ErrNoSubConnForPick. +func IsErrSubConnNoExist(err error) bool { + if errors.Is(err, ErrSubConnNotExist) { return true } if se, ok := err.(interface { GRPCStatus() *status.Status }); ok { - return errors.Is(se.GRPCStatus().Err(), ErrNoSubConnNotExist) + return errors.Is(se.GRPCStatus().Err(), ErrSubConnNotExist) } return false } diff --git a/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go b/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go index 35bd559ff9..707eba3464 100644 --- a/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go +++ b/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go @@ -91,13 +91,13 @@ func TestServerIDPickerBuilder(t *testing.T) { Ctx: contextutil.WithPickServerID(context.Background(), 4), }) assert.Error(t, err) - assert.ErrorIs(t, err, ErrNoSubConnNotExist) + assert.ErrorIs(t, err, ErrSubConnNotExist) assert.NotNil(t, info) } func TestIsErrNoSubConnForPick(t *testing.T) { - assert.True(t, IsErrNoSubConnForPick(ErrNoSubConnNotExist)) - assert.False(t, IsErrNoSubConnForPick(errors.New("test"))) - err := status.ConvertStreamingError("test", ErrNoSubConnNotExist) - assert.True(t, IsErrNoSubConnForPick(err)) + assert.True(t, IsErrSubConnNoExist(ErrSubConnNotExist)) + assert.False(t, IsErrSubConnNoExist(errors.New("test"))) + err := status.ConvertStreamingError("test", ErrSubConnNotExist) + assert.True(t, IsErrSubConnNoExist(err)) } diff --git a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go index 80d75d9fdf..2dc40858c4 100644 --- a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go +++ b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go @@ -58,7 +58,10 @@ func (d *channelAssignmentDiscoverer) parseState() VersionedState { for _, assignment := range d.lastDiscovery.Assignments { assignment := assignment addrs = append(addrs, resolver.Address{ - Addr: assignment.NodeInfo.Address, + Addr: assignment.NodeInfo.Address, + // resolverAttributes is important to use when resolving, server id to make resolver.Address with same adresss different. + Attributes: attributes.WithServerID(new(attributes.Attributes), assignment.NodeInfo.ServerID), + // balancerAttributes can be seen by picker of grpc balancer. BalancerAttributes: attributes.WithChannelAssignmentInfo(new(attributes.Attributes), &assignment), }) } diff --git a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go index 3886643fe0..6acb570bc7 100644 --- a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go +++ b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes" "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_types" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -92,6 +93,12 @@ func TestChannelAssignmentDiscoverer(t *testing.T) { idx++ return nil } + + // resolver attributes + for _, addr := range state.State.Addresses { + serverID := attributes.GetServerID(addr.Attributes) + assert.NotNil(t, serverID) + } return io.EOF }) assert.ErrorIs(t, err, io.EOF) diff --git a/internal/util/streamingutil/service/discoverer/session_discoverer.go b/internal/util/streamingutil/service/discoverer/session_discoverer.go index 03d42070ff..5cbc58b6e0 100644 --- a/internal/util/streamingutil/service/discoverer/session_discoverer.go +++ b/internal/util/streamingutil/service/discoverer/session_discoverer.go @@ -166,16 +166,18 @@ func (sw *sessionDiscoverer) parseState() VersionedState { continue } // filter low version. + // !!! important, stopping nodes should not be removed here. if !sw.versionRange(v) { sw.logger.Info("skip low version node", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version)) continue } - // !!! important, stopping nodes should not be removed here. - attr := new(attributes.Attributes) - attr = attributes.WithSession(attr, session) + addrs = append(addrs, resolver.Address{ - Addr: session.Address, - BalancerAttributes: attr, + Addr: session.Address, + // resolverAttributes is important to use when resolving, server id to make resolver.Address with same adresss different. + Attributes: attributes.WithServerID(new(attributes.Attributes), session.ServerID), + // balancerAttributes can be seen by picker of grpc balancer. + BalancerAttributes: attributes.WithSession(new(attributes.Attributes), session), }) } diff --git a/internal/util/streamingutil/service/discoverer/session_discoverer_test.go b/internal/util/streamingutil/service/discoverer/session_discoverer_test.go index 4705bf5f2b..304f8c8245 100644 --- a/internal/util/streamingutil/service/discoverer/session_discoverer_test.go +++ b/internal/util/streamingutil/service/discoverer/session_discoverer_test.go @@ -12,6 +12,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -96,8 +97,8 @@ func TestSessionDiscoverer(t *testing.T) { // Do a init discover here. d = NewSessionDiscoverer(etcdClient, "session/", targetVersion) err = d.Discover(ctx, func(state VersionedState) error { + // balance attributes sessions := state.Sessions() - expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx])) for k, v := range expected[idx] { if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) { @@ -105,6 +106,12 @@ func TestSessionDiscoverer(t *testing.T) { } } assert.Equal(t, expectedSessions, sessions) + + // resolver attributes + for _, addr := range state.State.Addresses { + serverID := attributes.GetServerID(addr.Attributes) + assert.NotNil(t, serverID) + } return io.EOF }) assert.ErrorIs(t, err, io.EOF) diff --git a/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go index bc1bf288cc..d861870c69 100644 --- a/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go +++ b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go @@ -43,6 +43,8 @@ func (r *watchBasedGRPCResolver) Close() { r.lifetime.Close() } +// Update updates the state of the resolver. +// Return error if the resolver is closed. func (r *watchBasedGRPCResolver) Update(state VersionedState) error { if r.lifetime.Add(lifetime.IsWorking) != nil { return errors.New("resolver is closed") @@ -50,8 +52,9 @@ func (r *watchBasedGRPCResolver) Update(state VersionedState) error { defer r.lifetime.Done() if err := r.cc.UpdateState(state.State); err != nil { - // watch based resolver could ignore the error. + // watch based resolver could ignore the error, just log and return nil r.logger.Warn("fail to update resolver state", zap.Error(err)) + return nil } r.logger.Info("update resolver state success", zap.Any("state", state.State)) return nil diff --git a/internal/util/streamingutil/typeconverter/deliver_test.go b/internal/util/streamingutil/typeconverter/deliver_test.go index 77ca100631..cd0ceb4a21 100644 --- a/internal/util/streamingutil/typeconverter/deliver_test.go +++ b/internal/util/streamingutil/typeconverter/deliver_test.go @@ -52,8 +52,8 @@ func TestDeliverPolicy(t *testing.T) { assert.Equal(t, policy.Policy(), policy2.Policy()) msgID := mock_message.NewMockMessageID(t) - msgID.EXPECT().Marshal().Return([]byte("mock")) - message.RegisterMessageIDUnmsarshaler("mock", func(b []byte) (message.MessageID, error) { + msgID.EXPECT().Marshal().Return("mock") + message.RegisterMessageIDUnmsarshaler("mock", func(b string) (message.MessageID, error) { return msgID, nil }) diff --git a/pkg/.mockery_pkg.yaml b/pkg/.mockery_pkg.yaml index 4e372c97a7..356c653799 100644 --- a/pkg/.mockery_pkg.yaml +++ b/pkg/.mockery_pkg.yaml @@ -26,4 +26,6 @@ packages: github.com/milvus-io/milvus/pkg/streaming/util/types: interfaces: AssignmentDiscoverWatcher: + AssignmentRebalanceTrigger: + \ No newline at end of file diff --git a/pkg/mocks/streaming/util/mock_message/mock_MessageID.go b/pkg/mocks/streaming/util/mock_message/mock_MessageID.go index d4371e2b3c..fca86396a7 100644 --- a/pkg/mocks/streaming/util/mock_message/mock_MessageID.go +++ b/pkg/mocks/streaming/util/mock_message/mock_MessageID.go @@ -147,16 +147,14 @@ func (_c *MockMessageID_LTE_Call) RunAndReturn(run func(message.MessageID) bool) } // Marshal provides a mock function with given fields: -func (_m *MockMessageID) Marshal() []byte { +func (_m *MockMessageID) Marshal() string { ret := _m.Called() - var r0 []byte - if rf, ok := ret.Get(0).(func() []byte); ok { + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } + r0 = ret.Get(0).(string) } return r0 @@ -179,12 +177,12 @@ func (_c *MockMessageID_Marshal_Call) Run(run func()) *MockMessageID_Marshal_Cal return _c } -func (_c *MockMessageID_Marshal_Call) Return(_a0 []byte) *MockMessageID_Marshal_Call { +func (_c *MockMessageID_Marshal_Call) Return(_a0 string) *MockMessageID_Marshal_Call { _c.Call.Return(_a0) return _c } -func (_c *MockMessageID_Marshal_Call) RunAndReturn(run func() []byte) *MockMessageID_Marshal_Call { +func (_c *MockMessageID_Marshal_Call) RunAndReturn(run func() string) *MockMessageID_Marshal_Call { _c.Call.Return(run) return _c } diff --git a/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go b/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go index d1649e94f2..89bca90ba3 100644 --- a/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go +++ b/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go @@ -190,15 +190,15 @@ func (_c *MockMutableMessage_Payload_Call) RunAndReturn(run func() []byte) *Mock } // Properties provides a mock function with given fields: -func (_m *MockMutableMessage) Properties() message.Properties { +func (_m *MockMutableMessage) Properties() message.RProperties { ret := _m.Called() - var r0 message.Properties - if rf, ok := ret.Get(0).(func() message.Properties); ok { + var r0 message.RProperties + if rf, ok := ret.Get(0).(func() message.RProperties); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(message.Properties) + r0 = ret.Get(0).(message.RProperties) } } @@ -222,12 +222,12 @@ func (_c *MockMutableMessage_Properties_Call) Run(run func()) *MockMutableMessag return _c } -func (_c *MockMutableMessage_Properties_Call) Return(_a0 message.Properties) *MockMutableMessage_Properties_Call { +func (_c *MockMutableMessage_Properties_Call) Return(_a0 message.RProperties) *MockMutableMessage_Properties_Call { _c.Call.Return(_a0) return _c } -func (_c *MockMutableMessage_Properties_Call) RunAndReturn(run func() message.Properties) *MockMutableMessage_Properties_Call { +func (_c *MockMutableMessage_Properties_Call) RunAndReturn(run func() message.RProperties) *MockMutableMessage_Properties_Call { _c.Call.Return(run) return _c } @@ -361,6 +361,50 @@ func (_c *MockMutableMessage_WithTimeTick_Call) RunAndReturn(run func(uint64) me return _c } +// WithVChannel provides a mock function with given fields: vChannel +func (_m *MockMutableMessage) WithVChannel(vChannel string) message.MutableMessage { + ret := _m.Called(vChannel) + + var r0 message.MutableMessage + if rf, ok := ret.Get(0).(func(string) message.MutableMessage); ok { + r0 = rf(vChannel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.MutableMessage) + } + } + + return r0 +} + +// MockMutableMessage_WithVChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithVChannel' +type MockMutableMessage_WithVChannel_Call struct { + *mock.Call +} + +// WithVChannel is a helper method to define mock.On call +// - vChannel string +func (_e *MockMutableMessage_Expecter) WithVChannel(vChannel interface{}) *MockMutableMessage_WithVChannel_Call { + return &MockMutableMessage_WithVChannel_Call{Call: _e.mock.On("WithVChannel", vChannel)} +} + +func (_c *MockMutableMessage_WithVChannel_Call) Run(run func(vChannel string)) *MockMutableMessage_WithVChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMutableMessage_WithVChannel_Call) Return(_a0 message.MutableMessage) *MockMutableMessage_WithVChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMutableMessage_WithVChannel_Call) RunAndReturn(run func(string) message.MutableMessage) *MockMutableMessage_WithVChannel_Call { + _c.Call.Return(run) + return _c +} + // NewMockMutableMessage creates a new instance of MockMutableMessage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockMutableMessage(t interface { diff --git a/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go b/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go index 9f2ddfd815..849c0c31b0 100644 --- a/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go +++ b/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go @@ -65,6 +65,50 @@ func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) RunAndReturn(ru return _c } +// ReportAssignmentError provides a mock function with given fields: ctx, pchannel, err +func (_m *MockAssignmentDiscoverWatcher) ReportAssignmentError(ctx context.Context, pchannel types.PChannelInfo, err error) error { + ret := _m.Called(ctx, pchannel, err) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo, error) error); ok { + r0 = rf(ctx, pchannel, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockAssignmentDiscoverWatcher_ReportAssignmentError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportAssignmentError' +type MockAssignmentDiscoverWatcher_ReportAssignmentError_Call struct { + *mock.Call +} + +// ReportAssignmentError is a helper method to define mock.On call +// - ctx context.Context +// - pchannel types.PChannelInfo +// - err error +func (_e *MockAssignmentDiscoverWatcher_Expecter) ReportAssignmentError(ctx interface{}, pchannel interface{}, err interface{}) *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call { + return &MockAssignmentDiscoverWatcher_ReportAssignmentError_Call{Call: _e.mock.On("ReportAssignmentError", ctx, pchannel, err)} +} + +func (_c *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfo, err error)) *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfo), args[2].(error)) + }) + return _c +} + +func (_c *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call) Return(_a0 error) *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call) RunAndReturn(run func(context.Context, types.PChannelInfo, error) error) *MockAssignmentDiscoverWatcher_ReportAssignmentError_Call { + _c.Call.Return(run) + return _c +} + // NewMockAssignmentDiscoverWatcher creates a new instance of MockAssignmentDiscoverWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockAssignmentDiscoverWatcher(t interface { diff --git a/pkg/mocks/streaming/util/mock_types/mock_AssignmentRebalanceTrigger.go b/pkg/mocks/streaming/util/mock_types/mock_AssignmentRebalanceTrigger.go new file mode 100644 index 0000000000..d5fbed96a8 --- /dev/null +++ b/pkg/mocks/streaming/util/mock_types/mock_AssignmentRebalanceTrigger.go @@ -0,0 +1,81 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_types + +import ( + context "context" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" + mock "github.com/stretchr/testify/mock" +) + +// MockAssignmentRebalanceTrigger is an autogenerated mock type for the AssignmentRebalanceTrigger type +type MockAssignmentRebalanceTrigger struct { + mock.Mock +} + +type MockAssignmentRebalanceTrigger_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAssignmentRebalanceTrigger) EXPECT() *MockAssignmentRebalanceTrigger_Expecter { + return &MockAssignmentRebalanceTrigger_Expecter{mock: &_m.Mock} +} + +// ReportAssignmentError provides a mock function with given fields: ctx, pchannel, err +func (_m *MockAssignmentRebalanceTrigger) ReportAssignmentError(ctx context.Context, pchannel types.PChannelInfo, err error) error { + ret := _m.Called(ctx, pchannel, err) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo, error) error); ok { + r0 = rf(ctx, pchannel, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockAssignmentRebalanceTrigger_ReportAssignmentError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportAssignmentError' +type MockAssignmentRebalanceTrigger_ReportAssignmentError_Call struct { + *mock.Call +} + +// ReportAssignmentError is a helper method to define mock.On call +// - ctx context.Context +// - pchannel types.PChannelInfo +// - err error +func (_e *MockAssignmentRebalanceTrigger_Expecter) ReportAssignmentError(ctx interface{}, pchannel interface{}, err interface{}) *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call { + return &MockAssignmentRebalanceTrigger_ReportAssignmentError_Call{Call: _e.mock.On("ReportAssignmentError", ctx, pchannel, err)} +} + +func (_c *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfo, err error)) *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfo), args[2].(error)) + }) + return _c +} + +func (_c *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call) Return(_a0 error) *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call) RunAndReturn(run func(context.Context, types.PChannelInfo, error) error) *MockAssignmentRebalanceTrigger_ReportAssignmentError_Call { + _c.Call.Return(run) + return _c +} + +// NewMockAssignmentRebalanceTrigger creates a new instance of MockAssignmentRebalanceTrigger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockAssignmentRebalanceTrigger(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAssignmentRebalanceTrigger { + mock := &MockAssignmentRebalanceTrigger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index 14996ea3e6..d5a0bfe56f 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -43,14 +43,14 @@ func (b *MutableMesasgeBuilder) WithPayload(payload []byte) *MutableMesasgeBuild } // WithProperty creates a new builder with message property. -// A key started with '_' is reserved for log system, should never used at user of client. +// A key started with '_' is reserved for streaming system, should never used at user of client. func (b *MutableMesasgeBuilder) WithProperty(key string, val string) *MutableMesasgeBuilder { b.properties.Set(key, val) return b } // WithProperties creates a new builder with message properties. -// A key started with '_' is reserved for log system, should never used at user of client. +// A key started with '_' is reserved for streaming system, should never used at user of client. func (b *MutableMesasgeBuilder) WithProperties(kvs map[string]string) *MutableMesasgeBuilder { for key, val := range kvs { b.properties.Set(key, val) diff --git a/pkg/streaming/util/message/encoder.go b/pkg/streaming/util/message/encoder.go new file mode 100644 index 0000000000..959619a53e --- /dev/null +++ b/pkg/streaming/util/message/encoder.go @@ -0,0 +1,25 @@ +package message + +import "strconv" + +const base = 36 + +// EncodeInt64 encodes int64 to string. +func EncodeInt64(value int64) string { + return strconv.FormatInt(value, base) +} + +// EncodeUint64 encodes uint64 to string. +func EncodeUint64(value uint64) string { + return strconv.FormatUint(value, base) +} + +// DecodeUint64 decodes string to uint64. +func DecodeUint64(value string) (uint64, error) { + return strconv.ParseUint(value, base, 64) +} + +// DecodeInt64 decodes string to int64. +func DecodeInt64(value string) (int64, error) { + return strconv.ParseInt(value, base, 64) +} diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go index bea23482bb..2c65995067 100644 --- a/pkg/streaming/util/message/message.go +++ b/pkg/streaming/util/message/message.go @@ -21,6 +21,10 @@ type BasicMessage interface { // EstimateSize returns the estimated size of message. EstimateSize() int + + // Properties returns the message properties. + // Should be used with read-only promise. + Properties() RProperties } // MutableMessage is the mutable message interface. @@ -29,23 +33,23 @@ type MutableMessage interface { BasicMessage // WithLastConfirmed sets the last confirmed message id of current message. - // !!! preserved for log system internal usage, don't call it outside of log system. + // !!! preserved for streaming system internal usage, don't call it outside of log system. WithLastConfirmed(id MessageID) MutableMessage // WithTimeTick sets the time tick of current message. - // !!! preserved for log system internal usage, don't call it outside of log system. + // !!! preserved for streaming system internal usage, don't call it outside of log system. WithTimeTick(tt uint64) MutableMessage - // Properties returns the message properties. - Properties() Properties + // WithVChannel sets the virtual channel of current message. + // !!! preserved for streaming system internal usage, don't call it outside of log system. + WithVChannel(vChannel string) MutableMessage // IntoImmutableMessage converts the mutable message to immutable message. IntoImmutableMessage(msgID MessageID) ImmutableMessage } // ImmutableMessage is the read-only message interface. -// Once a message is persistent by wal, it will be immutable. -// And the message id will be assigned. +// Once a message is persistent by wal or temporary generated by wal, it will be immutable. type ImmutableMessage interface { BasicMessage @@ -71,7 +75,4 @@ type ImmutableMessage interface { // MessageID returns the message id of current message. MessageID() MessageID - - // Properties returns the message read only properties. - Properties() RProperties } diff --git a/pkg/streaming/util/message/message_builder_test.go b/pkg/streaming/util/message/message_builder_test.go index e20d8d0e7d..87ad351fbe 100644 --- a/pkg/streaming/util/message/message_builder_test.go +++ b/pkg/streaming/util/message/message_builder_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" @@ -25,16 +24,16 @@ func TestMessage(t *testing.T) { assert.Equal(t, "value", v) assert.True(t, ok) assert.Equal(t, message.MessageTypeTimeTick, mutableMessage.MessageType()) - assert.Equal(t, 21, mutableMessage.EstimateSize()) + assert.Equal(t, 24, mutableMessage.EstimateSize()) mutableMessage.WithTimeTick(123) v, ok = mutableMessage.Properties().Get("_tt") assert.True(t, ok) - tt, n := proto.DecodeVarint([]byte(v)) + tt, err := message.DecodeUint64(v) assert.Equal(t, uint64(123), tt) - assert.Equal(t, len([]byte(v)), n) + assert.NoError(t, err) lcMsgID := mock_message.NewMockMessageID(t) - lcMsgID.EXPECT().Marshal().Return([]byte("lcMsgID")) + lcMsgID.EXPECT().Marshal().Return("lcMsgID") mutableMessage.WithLastConfirmed(lcMsgID) v, ok = mutableMessage.Properties().Get("_lc") assert.True(t, ok) @@ -43,8 +42,8 @@ func TestMessage(t *testing.T) { msgID := mock_message.NewMockMessageID(t) msgID.EXPECT().EQ(msgID).Return(true) msgID.EXPECT().WALName().Return("testMsgID") - message.RegisterMessageIDUnmsarshaler("testMsgID", func(data []byte) (message.MessageID, error) { - if string(data) == "lcMsgID" { + message.RegisterMessageIDUnmsarshaler("testMsgID", func(data string) (message.MessageID, error) { + if data == "lcMsgID" { return msgID, nil } panic(fmt.Sprintf("unexpected data: %s", data)) @@ -54,8 +53,8 @@ func TestMessage(t *testing.T) { []byte("payload"), map[string]string{ "key": "value", - "_t": "1", - "_tt": string(proto.EncodeVarint(456)), + "_t": "1200", + "_tt": message.EncodeUint64(456), "_v": "1", "_lc": "lcMsgID", }) @@ -67,7 +66,7 @@ func TestMessage(t *testing.T) { assert.Equal(t, "value", v) assert.True(t, ok) assert.Equal(t, message.MessageTypeTimeTick, immutableMessage.MessageType()) - assert.Equal(t, 36, immutableMessage.EstimateSize()) + assert.Equal(t, 39, immutableMessage.EstimateSize()) assert.Equal(t, message.Version(1), immutableMessage.Version()) assert.Equal(t, uint64(456), immutableMessage.TimeTick()) assert.NotNil(t, immutableMessage.LastConfirmedMessageID()) @@ -77,7 +76,7 @@ func TestMessage(t *testing.T) { []byte("payload"), map[string]string{ "key": "value", - "_t": "1", + "_t": "1200", }) assert.True(t, immutableMessage.MessageID().EQ(msgID)) @@ -87,7 +86,7 @@ func TestMessage(t *testing.T) { assert.Equal(t, "value", v) assert.True(t, ok) assert.Equal(t, message.MessageTypeTimeTick, immutableMessage.MessageType()) - assert.Equal(t, 18, immutableMessage.EstimateSize()) + assert.Equal(t, 21, immutableMessage.EstimateSize()) assert.Equal(t, message.Version(0), immutableMessage.Version()) assert.Panics(t, func() { immutableMessage.TimeTick() diff --git a/pkg/streaming/util/message/message_id.go b/pkg/streaming/util/message/message_id.go index d68ab616a1..b1d9fa14f8 100644 --- a/pkg/streaming/util/message/message_id.go +++ b/pkg/streaming/util/message/message_id.go @@ -22,10 +22,10 @@ func RegisterMessageIDUnmsarshaler(name string, unmarshaler MessageIDUnmarshaler } // MessageIDUnmarshaler is the unmarshaler for message id. -type MessageIDUnmarshaler = func(b []byte) (MessageID, error) +type MessageIDUnmarshaler = func(b string) (MessageID, error) // UnmsarshalMessageID unmarshal the message id. -func UnmarshalMessageID(name string, b []byte) (MessageID, error) { +func UnmarshalMessageID(name string, b string) (MessageID, error) { unmarshaler, ok := messageIDUnmarshaler.Get(name) if !ok { panic("MessageID Unmarshaler not registered: " + name) @@ -48,5 +48,5 @@ type MessageID interface { EQ(MessageID) bool // Marshal marshal the message id. - Marshal() []byte + Marshal() string } diff --git a/pkg/streaming/util/message/message_id_test.go b/pkg/streaming/util/message/message_id_test.go index b93ce2924c..437b35b385 100644 --- a/pkg/streaming/util/message/message_id_test.go +++ b/pkg/streaming/util/message/message_id_test.go @@ -1,7 +1,6 @@ package message_test import ( - "bytes" "testing" "github.com/cockroachdb/errors" @@ -14,28 +13,28 @@ import ( func TestRegisterMessageIDUnmarshaler(t *testing.T) { msgID := mock_message.NewMockMessageID(t) - message.RegisterMessageIDUnmsarshaler("test", func(b []byte) (message.MessageID, error) { - if bytes.Equal(b, []byte("123")) { + message.RegisterMessageIDUnmsarshaler("test", func(b string) (message.MessageID, error) { + if b == "123" { return msgID, nil } return nil, errors.New("invalid") }) - id, err := message.UnmarshalMessageID("test", []byte("123")) + id, err := message.UnmarshalMessageID("test", "123") assert.NotNil(t, id) assert.NoError(t, err) - id, err = message.UnmarshalMessageID("test", []byte("1234")) + id, err = message.UnmarshalMessageID("test", "1234") assert.Nil(t, id) assert.Error(t, err) assert.Panics(t, func() { - message.UnmarshalMessageID("test1", []byte("123")) + message.UnmarshalMessageID("test1", "123") }) assert.Panics(t, func() { - message.RegisterMessageIDUnmsarshaler("test", func(b []byte) (message.MessageID, error) { - if bytes.Equal(b, []byte("123")) { + message.RegisterMessageIDUnmsarshaler("test", func(b string) (message.MessageID, error) { + if b == "123" { return msgID, nil } return nil, errors.New("invalid") diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go index 47c5affb25..7b31bcd8e7 100644 --- a/pkg/streaming/util/message/message_impl.go +++ b/pkg/streaming/util/message/message_impl.go @@ -2,8 +2,6 @@ package message import ( "fmt" - - "github.com/golang/protobuf/proto" ) type messageImpl struct { @@ -35,7 +33,7 @@ func (m *messageImpl) Payload() []byte { } // Properties returns the message properties. -func (m *messageImpl) Properties() Properties { +func (m *messageImpl) Properties() RProperties { return m.properties } @@ -45,10 +43,15 @@ func (m *messageImpl) EstimateSize() int { return len(m.payload) + m.properties.EstimateSize() } +// WithVChannel sets the virtual channel of current message. +func (m *messageImpl) WithVChannel(vChannel string) MutableMessage { + m.properties.Set(messageVChannel, vChannel) + return m +} + // WithTimeTick sets the time tick of current message. func (m *messageImpl) WithTimeTick(tt uint64) MutableMessage { - t := proto.EncodeVarint(tt) - m.properties.Set(messageTimeTick, string(t)) + m.properties.Set(messageTimeTick, EncodeUint64(tt)) return m } @@ -82,10 +85,9 @@ func (m *immutableMessageImpl) TimeTick() uint64 { if !ok { panic(fmt.Sprintf("there's a bug in the message codes, timetick lost in properties of message, id: %+v", m.id)) } - v := []byte(value) - tt, n := proto.DecodeVarint(v) - if n != len(v) { - panic(fmt.Sprintf("there's a bug in the message codes, dirty timetick in properties of message, id: %+v", m.id)) + tt, err := DecodeUint64(value) + if err != nil { + panic(fmt.Sprintf("there's a bug in the message codes, dirty timetick %s in properties of message, id: %+v", value, m.id)) } return tt } @@ -95,7 +97,7 @@ func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID { if !ok { panic(fmt.Sprintf("there's a bug in the message codes, last confirmed message lost in properties of message, id: %+v", m.id)) } - id, err := UnmarshalMessageID(m.id.WALName(), []byte(value)) + id, err := UnmarshalMessageID(m.id.WALName(), value) if err != nil { panic(fmt.Sprintf("there's a bug in the message codes, dirty last confirmed message in properties of message, id: %+v", m.id)) } @@ -114,8 +116,3 @@ func (m *immutableMessageImpl) VChannel() string { } return value } - -// Properties returns the message read only properties. -func (m *immutableMessageImpl) Properties() RProperties { - return m.properties -} diff --git a/pkg/streaming/util/message/message_test.go b/pkg/streaming/util/message/message_test.go index f35094e08f..04be6f4913 100644 --- a/pkg/streaming/util/message/message_test.go +++ b/pkg/streaming/util/message/message_test.go @@ -11,9 +11,15 @@ func TestMessageType(t *testing.T) { assert.Equal(t, "0", s) typ := unmarshalMessageType("0") assert.Equal(t, MessageTypeUnknown, typ) + assert.False(t, MessageTypeUnknown.Valid()) typ = unmarshalMessageType("882s9") assert.Equal(t, MessageTypeUnknown, typ) + + s = MessageTypeTimeTick.marshal() + typ = unmarshalMessageType(s) + assert.Equal(t, MessageTypeTimeTick, typ) + assert.True(t, MessageTypeTimeTick.Valid()) } func TestVersion(t *testing.T) { diff --git a/pkg/streaming/util/message/message_type.go b/pkg/streaming/util/message/message_type.go index 61686106c7..3dd4859411 100644 --- a/pkg/streaming/util/message/message_type.go +++ b/pkg/streaming/util/message/message_type.go @@ -1,17 +1,35 @@ package message -import "strconv" +import ( + "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) type MessageType int32 const ( - MessageTypeUnknown MessageType = 0 - MessageTypeTimeTick MessageType = 1 + MessageTypeUnknown MessageType = MessageType(commonpb.MsgType_Undefined) + MessageTypeTimeTick MessageType = MessageType(commonpb.MsgType_TimeTick) + MessageTypeInsert MessageType = MessageType(commonpb.MsgType_Insert) + MessageTypeDelete MessageType = MessageType(commonpb.MsgType_Delete) + MessageTypeFlush MessageType = MessageType(commonpb.MsgType_Flush) + MessageTypeCreateCollection MessageType = MessageType(commonpb.MsgType_CreateCollection) + MessageTypeDropCollection MessageType = MessageType(commonpb.MsgType_DropCollection) + MessageTypeCreatePartition MessageType = MessageType(commonpb.MsgType_CreatePartition) + MessageTypeDropPartition MessageType = MessageType(commonpb.MsgType_DropPartition) ) var messageTypeName = map[MessageType]string{ - MessageTypeUnknown: "MESSAGE_TYPE_UNKNOWN", - MessageTypeTimeTick: "MESSAGE_TYPE_TIME_TICK", + MessageTypeUnknown: "UNKNOWN", + MessageTypeTimeTick: "TIME_TICK", + MessageTypeInsert: "INSERT", + MessageTypeDelete: "DELETE", + MessageTypeFlush: "FLUSH", + MessageTypeCreateCollection: "CREATE_COLLECTION", + MessageTypeDropCollection: "DROP_COLLECTION", + MessageTypeCreatePartition: "CREATE_PARTITION", + MessageTypeDropPartition: "DROP_PARTITION", } // String implements fmt.Stringer interface. @@ -26,8 +44,8 @@ func (t MessageType) marshal() string { // Valid checks if the MessageType is valid. func (t MessageType) Valid() bool { - return t == MessageTypeTimeTick - // TODO: fill more. + _, ok := messageTypeName[t] + return t != MessageTypeUnknown && ok } // unmarshalMessageType unmarshal MessageType from string. diff --git a/pkg/streaming/util/options/deliver.go b/pkg/streaming/util/options/deliver.go index 71e1416629..9ea15ea381 100644 --- a/pkg/streaming/util/options/deliver.go +++ b/pkg/streaming/util/options/deliver.go @@ -88,3 +88,8 @@ func DeliverFilterVChannel(vchannel string) DeliverFilter { vchannel: vchannel, } } + +// IsDeliverFilterTimeTick checks if the filter is time tick filter. +func IsDeliverFilterTimeTick(filter DeliverFilter) bool { + return filter.Type() == DeliverFilterTypeTimeTickGT || filter.Type() == DeliverFilterTypeTimeTickGTE +} diff --git a/pkg/streaming/util/types/pchannel_info.go b/pkg/streaming/util/types/pchannel_info.go index 6a4d65c26f..557a7d68c2 100644 --- a/pkg/streaming/util/types/pchannel_info.go +++ b/pkg/streaming/util/types/pchannel_info.go @@ -1,5 +1,7 @@ package types +import "fmt" + const ( InitialTerm int64 = -1 ) @@ -10,6 +12,10 @@ type PChannelInfo struct { Term int64 // term of pchannel. } +func (c *PChannelInfo) String() string { + return fmt.Sprintf("%s@%d", c.Name, c.Term) +} + type PChannelInfoAssigned struct { Channel PChannelInfo Node StreamingNodeInfo diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go index 8719b50e84..24f6a6cb1b 100644 --- a/pkg/streaming/util/types/streaming_node.go +++ b/pkg/streaming/util/types/streaming_node.go @@ -19,6 +19,15 @@ type AssignmentDiscoverWatcher interface { // The callback will be called when the discovery is changed. // The final error will be returned when the watcher is closed or broken. AssignmentDiscover(ctx context.Context, cb func(*VersionedStreamingNodeAssignments) error) error + + AssignmentRebalanceTrigger +} + +// AssignmentRebalanceTrigger is the interface for triggering the re-balance of the pchannel. +type AssignmentRebalanceTrigger interface { + // ReportStreamingError is used to report the streaming error. + // Trigger a re-balance of the pchannel. + ReportAssignmentError(ctx context.Context, pchannel PChannelInfo, err error) error } // VersionedStreamingNodeAssignments is the relation between server and channels with version. diff --git a/pkg/streaming/walimpls/impls/pulsar/message_id.go b/pkg/streaming/walimpls/impls/pulsar/message_id.go index 3214dd2959..59614dbe35 100644 --- a/pkg/streaming/walimpls/impls/pulsar/message_id.go +++ b/pkg/streaming/walimpls/impls/pulsar/message_id.go @@ -1,14 +1,17 @@ package pulsar import ( + "encoding/hex" + "github.com/apache/pulsar-client-go/pulsar" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/pkg/streaming/util/message" ) var _ message.MessageID = pulsarID{} -func UnmarshalMessageID(data []byte) (message.MessageID, error) { +func UnmarshalMessageID(data string) (message.MessageID, error) { id, err := unmarshalMessageID(data) if err != nil { return nil, err @@ -16,10 +19,14 @@ func UnmarshalMessageID(data []byte) (message.MessageID, error) { return id, nil } -func unmarshalMessageID(data []byte) (pulsarID, error) { - msgID, err := pulsar.DeserializeMessageID(data) +func unmarshalMessageID(data string) (pulsarID, error) { + val, err := hex.DecodeString(data) if err != nil { - return pulsarID{nil}, err + return pulsarID{nil}, errors.Wrapf(message.ErrInvalidMessageID, "decode pulsar fail when decode hex with err: %s, id: %s", err.Error(), data) + } + msgID, err := pulsar.DeserializeMessageID(val) + if err != nil { + return pulsarID{nil}, errors.Wrapf(message.ErrInvalidMessageID, "decode pulsar fail when deserialize with err: %s, id: %s", err.Error(), data) } return pulsarID{msgID}, nil } @@ -61,6 +68,6 @@ func (id pulsarID) EQ(other message.MessageID) bool { id.BatchIdx() == id2.BatchIdx() } -func (id pulsarID) Marshal() []byte { - return id.Serialize() +func (id pulsarID) Marshal() string { + return hex.EncodeToString(id.Serialize()) } diff --git a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go index 599014480f..c63422f20e 100644 --- a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go +++ b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go @@ -36,7 +36,7 @@ func TestMessageID(t *testing.T) { assert.NoError(t, err) assert.True(t, msgID.EQ(pulsarID{newMessageIDOfPulsar(1, 2, 3)})) - _, err = UnmarshalMessageID([]byte{0x01, 0x02, 0x03, 0x04}) + _, err = UnmarshalMessageID(string([]byte{0x01, 0x02, 0x03, 0x04})) assert.Error(t, err) } diff --git a/pkg/streaming/walimpls/impls/rmq/message_id.go b/pkg/streaming/walimpls/impls/rmq/message_id.go index 59c7773387..51637822fd 100644 --- a/pkg/streaming/walimpls/impls/rmq/message_id.go +++ b/pkg/streaming/walimpls/impls/rmq/message_id.go @@ -1,11 +1,7 @@ package rmq import ( - "encoding/base64" - "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" - "google.golang.org/protobuf/encoding/protowire" "github.com/milvus-io/milvus/pkg/streaming/util/message" ) @@ -13,7 +9,7 @@ import ( var _ message.MessageID = rmqID(0) // UnmarshalMessageID unmarshal the message id. -func UnmarshalMessageID(data []byte) (message.MessageID, error) { +func UnmarshalMessageID(data string) (message.MessageID, error) { id, err := unmarshalMessageID(data) if err != nil { return nil, err @@ -22,12 +18,12 @@ func UnmarshalMessageID(data []byte) (message.MessageID, error) { } // unmashalMessageID unmarshal the message id. -func unmarshalMessageID(data []byte) (rmqID, error) { - v, n := proto.DecodeVarint(data) - if n <= 0 || n != len(data) { - return 0, errors.Wrapf(message.ErrInvalidMessageID, "rmqID: %s", base64.RawStdEncoding.EncodeToString(data)) +func unmarshalMessageID(data string) (rmqID, error) { + v, err := message.DecodeUint64(data) + if err != nil { + return 0, errors.Wrapf(message.ErrInvalidMessageID, "decode rmqID fail with err: %s, id: %s", err.Error(), data) } - return rmqID(protowire.DecodeZigZag(v)), nil + return rmqID(v), nil } // rmqID is the message id for rmq. @@ -54,6 +50,6 @@ func (id rmqID) EQ(other message.MessageID) bool { } // Marshal marshal the message id. -func (id rmqID) Marshal() []byte { - return proto.EncodeVarint(protowire.EncodeZigZag(int64(id))) +func (id rmqID) Marshal() string { + return message.EncodeInt64(int64(id)) } diff --git a/pkg/streaming/walimpls/impls/rmq/message_id_test.go b/pkg/streaming/walimpls/impls/rmq/message_id_test.go index 9e38751918..b757e57ab6 100644 --- a/pkg/streaming/walimpls/impls/rmq/message_id_test.go +++ b/pkg/streaming/walimpls/impls/rmq/message_id_test.go @@ -20,6 +20,6 @@ func TestMessageID(t *testing.T) { assert.NoError(t, err) assert.Equal(t, rmqID(1), msgID) - _, err = UnmarshalMessageID([]byte{0x01, 0x02, 0x03, 0x04}) + _, err = UnmarshalMessageID(string([]byte{0x01, 0x02, 0x03, 0x04})) assert.Error(t, err) } diff --git a/pkg/streaming/walimpls/impls/walimplstest/message_id.go b/pkg/streaming/walimpls/impls/walimplstest/message_id.go index 711d0047cc..b36d775381 100644 --- a/pkg/streaming/walimpls/impls/walimplstest/message_id.go +++ b/pkg/streaming/walimpls/impls/walimplstest/message_id.go @@ -4,8 +4,6 @@ package walimplstest import ( - "strconv" - "github.com/milvus-io/milvus/pkg/streaming/util/message" ) @@ -17,7 +15,7 @@ func NewTestMessageID(id int64) message.MessageID { } // UnmarshalTestMessageID unmarshal the message id. -func UnmarshalTestMessageID(data []byte) (message.MessageID, error) { +func UnmarshalTestMessageID(data string) (message.MessageID, error) { id, err := unmarshalTestMessageID(data) if err != nil { return nil, err @@ -26,8 +24,8 @@ func UnmarshalTestMessageID(data []byte) (message.MessageID, error) { } // unmashalTestMessageID unmarshal the message id. -func unmarshalTestMessageID(data []byte) (testMessageID, error) { - id, err := strconv.ParseInt(string(data), 10, 64) +func unmarshalTestMessageID(data string) (testMessageID, error) { + id, err := message.DecodeInt64(data) if err != nil { return 0, err } @@ -58,6 +56,6 @@ func (id testMessageID) EQ(other message.MessageID) bool { } // Marshal marshal the message id. -func (id testMessageID) Marshal() []byte { - return []byte(strconv.FormatInt(int64(id), 10)) +func (id testMessageID) Marshal() string { + return message.EncodeInt64(int64(id)) } diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index 933104ce93..7cbf55a11e 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -120,6 +120,7 @@ go test -race -cover -tags dynamic,test "${PKG_DIR}/log/..." -failfast -count=1 go test -race -cover -tags dynamic,test "${PKG_DIR}/mq/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${PKG_DIR}/tracer/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${PKG_DIR}/util/..." -failfast -count=1 -ldflags="-r ${RPATH}" +go test -race -cover -tags dynamic,test "${PKG_DIR}/streaming/..." -failfast -count=1 -ldflags="-r ${RPATH}" popd } @@ -169,6 +170,9 @@ function test_streaming() go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingcoord/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingnode/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/streamingutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" +pushd pkg +go test -race -cover -tags dynamic,test "${PKG_DIR}/streaming/..." -failfast -count=1 -ldflags="-r ${RPATH}" +popd } function test_all()