package sessionutil import ( "context" "errors" "fmt" "math/rand" "strconv" "strings" "sync" "testing" "time" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.etcd.io/etcd/api/v3/mvccpb" v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" clientv3 "go.etcd.io/etcd/client/v3" ) var Params paramtable.BaseTable func TestGetServerIDConcurrently(t *testing.T) { ctx := context.Background() Params.Init() endpoints, err := Params.Load("_EtcdEndpoints") metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) if err != nil { panic(err) } etcdEndpoints := strings.Split(endpoints, ",") etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, metaRoot) assert.NoError(t, err) err = etcdKV.RemoveWithPrefix("") assert.NoError(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") var wg sync.WaitGroup var muList = sync.Mutex{} s := NewSession(ctx, metaRoot, etcdEndpoints) res := make([]int64, 0) getIDFunc := func() { s.checkIDExist() id, err := s.getServerID() assert.Nil(t, err) muList.Lock() res = append(res, id) muList.Unlock() wg.Done() } for i := 0; i < 10; i++ { wg.Add(1) go getIDFunc() } wg.Wait() for i := 1; i <= 10; i++ { assert.Contains(t, res, int64(i)) } } func TestInit(t *testing.T) { ctx := context.Background() Params.Init() endpoints, err := Params.Load("_EtcdEndpoints") if err != nil { panic(err) } metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) etcdEndpoints := strings.Split(endpoints, ",") etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, metaRoot) assert.NoError(t, err) err = etcdKV.RemoveWithPrefix("") assert.NoError(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") s := NewSession(ctx, metaRoot, etcdEndpoints) s.Init("inittest", "testAddr", false) assert.NotEqual(t, int64(0), s.leaseID) assert.NotEqual(t, int64(0), s.ServerID) sessions, _, err := s.GetSessions("inittest") assert.Nil(t, err) assert.Contains(t, sessions, "inittest-"+strconv.FormatInt(s.ServerID, 10)) } func TestUpdateSessions(t *testing.T) { ctx := context.Background() Params.Init() endpoints, err := Params.Load("_EtcdEndpoints") if err != nil { panic(err) } etcdEndpoints := strings.Split(endpoints, ",") metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, "") assert.NoError(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") var wg sync.WaitGroup var muList = sync.Mutex{} s := NewSession(ctx, metaRoot, etcdEndpoints) sessions, rev, err := s.GetSessions("test") assert.Nil(t, err) assert.Equal(t, len(sessions), 0) eventCh := s.WatchServices("test", rev, nil) sList := []*Session{} getIDFunc := func() { singleS := NewSession(ctx, metaRoot, etcdEndpoints) singleS.Init("test", "testAddr", false) muList.Lock() sList = append(sList, singleS) muList.Unlock() wg.Done() } for i := 0; i < 10; i++ { wg.Add(1) go getIDFunc() } wg.Wait() assert.Eventually(t, func() bool { sessions, _, _ := s.GetSessions("test") return len(sessions) == 10 }, 10*time.Second, 100*time.Millisecond) notExistSessions, _, _ := s.GetSessions("testt") assert.Equal(t, len(notExistSessions), 0) etcdKV.RemoveWithPrefix(metaRoot) assert.Eventually(t, func() bool { sessions, _, _ := s.GetSessions("test") return len(sessions) == 0 }, 10*time.Second, 100*time.Millisecond) sessionEvents := []*SessionEvent{} addEventLen := 0 delEventLen := 0 eventLength := len(eventCh) for i := 0; i < eventLength; i++ { sessionEvent := <-eventCh if sessionEvent.EventType == SessionAddEvent { addEventLen++ } if sessionEvent.EventType == SessionDelEvent { delEventLen++ } sessionEvents = append(sessionEvents, sessionEvent) } assert.Equal(t, len(sessionEvents), 20) assert.Equal(t, addEventLen, 10) assert.Equal(t, delEventLen, 10) } func TestSessionLivenessCheck(t *testing.T) { s := &Session{} ctx := context.Background() ch := make(chan bool) s.liveCh = ch signal := make(chan struct{}, 1) flag := false go s.LivenessCheck(ctx, func() { flag = true signal <- struct{}{} }) assert.False(t, flag) ch <- true assert.False(t, flag) close(ch) <-signal assert.True(t, flag) ctx, cancel := context.WithCancel(ctx) cancel() ch = make(chan bool) s.liveCh = ch flag = false go s.LivenessCheck(ctx, func() { flag = true signal <- struct{}{} }) assert.False(t, flag) } func TestWatcherHandleWatchResp(t *testing.T) { ctx := context.Background() Params.Init() endpoints, err := Params.Load("_EtcdEndpoints") require.NoError(t, err) etcdEndpoints := strings.Split(endpoints, ",") metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, "/by-dev/session-ut") require.NoError(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("/by-dev/session-ut") s := NewSession(ctx, metaRoot, etcdEndpoints) defer s.Revoke(time.Second) getWatcher := func(s *Session, rewatch Rewatch) *sessionWatcher { return &sessionWatcher{ s: s, prefix: "test", rewatch: rewatch, eventCh: make(chan *SessionEvent, 10), } } t.Run("handle normal events", func(t *testing.T) { w := getWatcher(s, nil) wresp := clientv3.WatchResponse{ Events: []*clientv3.Event{ { Type: mvccpb.PUT, Kv: &mvccpb.KeyValue{ Value: []byte(`{"ServerID": 1, "ServerName": "test1"}`), }, }, { Type: mvccpb.DELETE, PrevKv: &mvccpb.KeyValue{ Value: []byte(`{"ServerID": 2, "ServerName": "test2"}`), }, }, }, } err := w.handleWatchResponse(wresp) assert.NoError(t, err) assert.Equal(t, 2, len(w.eventCh)) }) t.Run("handle abnormal events", func(t *testing.T) { w := getWatcher(s, nil) wresp := clientv3.WatchResponse{ Events: []*clientv3.Event{ { Type: mvccpb.PUT, Kv: &mvccpb.KeyValue{ Value: []byte(``), }, }, { Type: mvccpb.DELETE, PrevKv: &mvccpb.KeyValue{ Value: []byte(``), }, }, }, } var err error assert.NotPanics(t, func() { err = w.handleWatchResponse(wresp) }) assert.NoError(t, err) assert.Equal(t, 0, len(w.eventCh)) }) t.Run("err compacted resp, nil Rewatch", func(t *testing.T) { w := getWatcher(s, nil) wresp := clientv3.WatchResponse{ CompactRevision: 1, } err := w.handleWatchResponse(wresp) assert.Error(t, err) assert.Equal(t, v3rpc.ErrCompacted, err) }) t.Run("err compacted resp, valid Rewatch", func(t *testing.T) { w := getWatcher(s, func(sessions map[string]*Session) error { return nil }) wresp := clientv3.WatchResponse{ CompactRevision: 1, } err := w.handleWatchResponse(wresp) assert.NoError(t, err) }) t.Run("err canceled", func(t *testing.T) { w := getWatcher(s, nil) wresp := clientv3.WatchResponse{ Canceled: true, } err := w.handleWatchResponse(wresp) assert.Error(t, err) }) t.Run("err handled but list failed", func(t *testing.T) { s := NewSession(ctx, "/by-dev/session-ut", etcdEndpoints) s.etcdCli.Close() w := getWatcher(s, func(sessions map[string]*Session) error { return nil }) wresp := clientv3.WatchResponse{ CompactRevision: 1, } err = w.handleWatchResponse(wresp) assert.Error(t, err) }) t.Run("err handled but rewatch failed", func(t *testing.T) { w := getWatcher(s, func(sessions map[string]*Session) error { return errors.New("mocked") }) wresp := clientv3.WatchResponse{ CompactRevision: 1, } err := w.handleWatchResponse(wresp) assert.Error(t, err) }) } func TestSessionRevoke(t *testing.T) { s := &Session{} assert.NotPanics(t, func() { s.Revoke(time.Second) }) s = (*Session)(nil) assert.NotPanics(t, func() { s.Revoke(time.Second) }) ctx := context.Background() Params.Init() endpoints, err := Params.Load("_EtcdEndpoints") if err != nil { panic(err) } metaRoot := fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) etcdEndpoints := strings.Split(endpoints, ",") etcdKV, err := etcdkv.NewEtcdKV(etcdEndpoints, metaRoot) assert.NoError(t, err) err = etcdKV.RemoveWithPrefix("") assert.NoError(t, err) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") s = NewSession(ctx, metaRoot, etcdEndpoints) s.Init("revoketest", "testAddr", false) assert.NotPanics(t, func() { s.Revoke(time.Second) }) } func TestSession_Registered(t *testing.T) { session := &Session{} session.UpdateRegistered(false) assert.False(t, session.Registered()) session.UpdateRegistered(true) assert.True(t, session.Registered()) }