mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 11:59:00 +08:00
Update session manager (#27761)
See also: #25309 Signed-off-by: yangxuan <xuan.yang@zilliz.com>
This commit is contained in:
parent
da19e49daf
commit
2bbc20c7b8
@ -31,6 +31,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/retry"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
@ -249,15 +250,12 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
log.Info("Get State failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Info("Get State failed", zap.String("Reason", resp.GetStatus().GetReason()))
|
||||
return
|
||||
}
|
||||
for _, rst := range resp.GetResults() {
|
||||
plans.Insert(rst.PlanID, rst)
|
||||
}
|
||||
@ -296,6 +294,46 @@ func (c *SessionManager) FlushChannels(ctx context.Context, nodeID int64, req *d
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SessionManager) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error {
|
||||
log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID))
|
||||
cli, err := c.getClient(ctx, nodeID)
|
||||
if err != nil {
|
||||
log.Info("failed to get dataNode client", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second))
|
||||
defer cancel()
|
||||
resp, err := cli.NotifyChannelOperation(ctx, req)
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
log.Warn("Notify channel operations failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SessionManager) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) {
|
||||
log := log.With(
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.String("channel", info.GetVchan().GetChannelName()),
|
||||
zap.String("operation", info.GetState().String()),
|
||||
)
|
||||
cli, err := c.getClient(ctx, nodeID)
|
||||
if err != nil {
|
||||
log.Info("failed to get dataNode client", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second))
|
||||
defer cancel()
|
||||
resp, err := cli.CheckChannelOperationProgress(ctx, info)
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
log.Warn("Check channel operation failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.DataNodeClient, error) {
|
||||
c.sessions.RLock()
|
||||
session, ok := c.sessions.data[nodeID]
|
||||
|
117
internal/datacoord/session_manager_test.go
Normal file
117
internal/datacoord/session_manager_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
package datacoord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func TestSessionManagerSuite(t *testing.T) {
|
||||
suite.Run(t, new(SessionManagerSuite))
|
||||
}
|
||||
|
||||
type SessionManagerSuite struct {
|
||||
suite.Suite
|
||||
|
||||
dn *mocks.MockDataNodeClient
|
||||
|
||||
m *SessionManager
|
||||
}
|
||||
|
||||
func (s *SessionManagerSuite) SetupTest() {
|
||||
s.dn = mocks.NewMockDataNodeClient(s.T())
|
||||
|
||||
s.m = NewSessionManager(withSessionCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
|
||||
return s.dn, nil
|
||||
}))
|
||||
|
||||
s.m.AddSession(&NodeInfo{1000, "addr-1"})
|
||||
}
|
||||
|
||||
func (s *SessionManagerSuite) TestNotifyChannelOperation() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
info := &datapb.ChannelWatchInfo{
|
||||
Vchan: &datapb.VchannelInfo{},
|
||||
State: datapb.ChannelWatchState_ToWatch,
|
||||
OpID: 1,
|
||||
}
|
||||
|
||||
req := &datapb.ChannelOperationsRequest{
|
||||
Infos: []*datapb.ChannelWatchInfo{info},
|
||||
}
|
||||
s.Run("no node", func() {
|
||||
err := s.m.NotifyChannelOperation(ctx, 100, req)
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("fail", func() {
|
||||
s.SetupTest()
|
||||
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
|
||||
|
||||
err := s.m.NotifyChannelOperation(ctx, 1000, req)
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(merr.Status(nil), nil)
|
||||
|
||||
err := s.m.NotifyChannelOperation(ctx, 1000, req)
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
info := &datapb.ChannelWatchInfo{
|
||||
Vchan: &datapb.VchannelInfo{},
|
||||
State: datapb.ChannelWatchState_ToWatch,
|
||||
OpID: 1,
|
||||
}
|
||||
|
||||
s.Run("no node", func() {
|
||||
resp, err := s.m.CheckChannelOperationProgress(ctx, 100, info)
|
||||
s.Error(err)
|
||||
s.Nil(resp)
|
||||
})
|
||||
|
||||
s.Run("fail", func() {
|
||||
s.SetupTest()
|
||||
s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock"))
|
||||
|
||||
resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info)
|
||||
s.Error(err)
|
||||
s.Nil(resp)
|
||||
})
|
||||
|
||||
s.Run("normal", func() {
|
||||
s.SetupTest()
|
||||
s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(
|
||||
&datapb.ChannelOperationProgressResponse{
|
||||
Status: merr.Status(nil),
|
||||
OpID: info.OpID,
|
||||
State: info.State,
|
||||
Progress: 100,
|
||||
},
|
||||
nil)
|
||||
|
||||
resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info)
|
||||
s.NoError(err)
|
||||
s.Equal(resp.GetState(), info.State)
|
||||
s.Equal(resp.OpID, info.OpID)
|
||||
s.EqualValues(100, resp.Progress)
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user