mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-01 11:29:48 +08:00
1c9647ff31
Signed-off-by: Zach41 <zongmei.zhang@zilliz.com>
198 lines
5.5 KiB
Go
198 lines
5.5 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func TestUpdateShardsWithRoundRobin(t *testing.T) {
|
|
list := map[string][]nodeInfo{
|
|
"channel-1": {
|
|
{1, "addr1"},
|
|
{2, "addr2"},
|
|
},
|
|
"channel-2": {
|
|
{20, "addr20"},
|
|
{21, "addr21"},
|
|
},
|
|
}
|
|
|
|
updateShardsWithRoundRobin(list)
|
|
|
|
assert.Equal(t, int64(2), list["channel-1"][0].nodeID)
|
|
assert.Equal(t, "addr2", list["channel-1"][0].address)
|
|
assert.Equal(t, int64(21), list["channel-2"][0].nodeID)
|
|
assert.Equal(t, "addr21", list["channel-2"][0].address)
|
|
|
|
t.Run("check print", func(t *testing.T) {
|
|
qns := []nodeInfo{
|
|
{1, "addr1"},
|
|
{2, "addr2"},
|
|
{20, "addr20"},
|
|
{21, "addr21"},
|
|
}
|
|
|
|
res := fmt.Sprintf("list: %v", qns)
|
|
|
|
log.Debug("Check String func",
|
|
zap.Any("Any", qns),
|
|
zap.Any("ok", qns[0]),
|
|
zap.String("ok2", res),
|
|
)
|
|
|
|
})
|
|
}
|
|
|
|
func TestGroupShardLeadersWithSameQueryNode(t *testing.T) {
|
|
var err error
|
|
|
|
Params.Init()
|
|
var (
|
|
ctx = context.TODO()
|
|
)
|
|
|
|
mgr := newShardClientMgr()
|
|
|
|
shard2leaders := map[string][]nodeInfo{
|
|
"c0": {{nodeID: 0, address: "fake"}, {nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}},
|
|
"c1": {{nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
"c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
"c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}},
|
|
}
|
|
mgr.UpdateShardLeaders(nil, shard2leaders)
|
|
nexts := map[string]int{
|
|
"c0": 0,
|
|
"c1": 0,
|
|
"c2": 0,
|
|
"c3": 0,
|
|
}
|
|
errSet := map[string]error{}
|
|
node2dmls, qnSet, err := groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
assert.Nil(t, err)
|
|
for nodeID := range node2dmls {
|
|
sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] })
|
|
}
|
|
|
|
cli0, err := mgr.GetClient(ctx, 0)
|
|
assert.Nil(t, err)
|
|
cli1, err := mgr.GetClient(ctx, 1)
|
|
assert.Nil(t, err)
|
|
cli2, err := mgr.GetClient(ctx, 2)
|
|
assert.Nil(t, err)
|
|
cli3, err := mgr.GetClient(ctx, 3)
|
|
assert.Nil(t, err)
|
|
|
|
assert.Equal(t, node2dmls, map[int64][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}})
|
|
assert.Equal(t, qnSet, map[int64]types.QueryNode{0: cli0, 1: cli1})
|
|
assert.Equal(t, nexts, map[string]int{"c0": 1, "c1": 1, "c2": 1, "c3": 1})
|
|
// delete client1 in client mgr
|
|
delete(mgr.clients.data, 1)
|
|
node2dmls, qnSet, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
assert.Nil(t, err)
|
|
for nodeID := range node2dmls {
|
|
sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] })
|
|
}
|
|
assert.Equal(t, node2dmls, map[int64][]string{2: {"c1", "c2"}, 3: {"c3"}})
|
|
assert.Equal(t, qnSet, map[int64]types.QueryNode{2: cli2, 3: cli3})
|
|
assert.Equal(t, nexts, map[string]int{"c0": 2, "c1": 2, "c2": 2, "c3": 2})
|
|
assert.NotNil(t, errSet["c0"])
|
|
|
|
nexts["c0"] = 3
|
|
_, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
assert.Equal(t, err, errSet["c0"])
|
|
|
|
nexts["c0"] = 2
|
|
nexts["c1"] = 3
|
|
_, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
assert.Equal(t, err, fmt.Errorf("no available shard leader"))
|
|
}
|
|
|
|
func TestMergeRoundRobinPolicy(t *testing.T) {
|
|
var err error
|
|
|
|
Params.Init()
|
|
var (
|
|
ctx = context.TODO()
|
|
)
|
|
|
|
mgr := newShardClientMgr()
|
|
|
|
shard2leaders := map[string][]nodeInfo{
|
|
"c0": {{nodeID: 0, address: "fake"}, {nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}},
|
|
"c1": {{nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
"c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
"c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}},
|
|
}
|
|
mgr.UpdateShardLeaders(nil, shard2leaders)
|
|
|
|
querier := &mockQuery{}
|
|
querier.init()
|
|
|
|
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, querier.records(), map[UniqueID][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}})
|
|
|
|
mockerr := fmt.Errorf("mock query node error")
|
|
querier.init()
|
|
querier.failset[0] = mockerr
|
|
|
|
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, querier.records(), map[int64][]string{1: {"c0", "c1", "c3"}, 2: {"c2"}})
|
|
|
|
querier.init()
|
|
querier.failset[0] = mockerr
|
|
querier.failset[2] = mockerr
|
|
querier.failset[3] = mockerr
|
|
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
|
assert.Equal(t, err, mockerr)
|
|
}
|
|
|
|
func mockQueryNodeCreator(ctx context.Context, address string) (types.QueryNode, error) {
|
|
return &QueryNodeMock{address: address}, nil
|
|
}
|
|
|
|
type mockQuery struct {
|
|
mu sync.Mutex
|
|
queryset map[UniqueID][]string
|
|
failset map[UniqueID]error
|
|
}
|
|
|
|
func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs []string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if err, ok := m.failset[nodeID]; ok {
|
|
return err
|
|
}
|
|
m.queryset[nodeID] = append(m.queryset[nodeID], chs...)
|
|
return nil
|
|
}
|
|
|
|
func (m *mockQuery) init() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.queryset = make(map[int64][]string)
|
|
m.failset = make(map[int64]error)
|
|
}
|
|
|
|
func (m *mockQuery) records() map[UniqueID][]string {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
for nodeID := range m.queryset {
|
|
sort.Slice(m.queryset[nodeID], func(i, j int) bool {
|
|
return m.queryset[nodeID][i] < m.queryset[nodeID][j]
|
|
})
|
|
}
|
|
return m.queryset
|
|
}
|