milvus/internal/datanode/channel_manager_test.go
congqixia d9efea2fea
fix: Cleanup write buffer when flowgraph released (#31376)
See also #30137

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2024-03-19 01:33:05 +08:00

213 lines
6.4 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datanode
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datanode/allocator"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestChannelManagerSuite(t *testing.T) {
suite.Run(t, new(ChannelManagerSuite))
}
// ChannelManagerSuite is the testify suite holding the fixtures shared by the
// channel manager tests in this file.
type ChannelManagerSuite struct {
	suite.Suite

	// node is the mock DataNode that SetupTest creates and passes to the manager.
	node *DataNode
	// manager is the channel manager under test; SetupTest creates it and
	// TearDownTest closes it.
	manager *ChannelManagerImpl
}
// SetupTest runs before each test. It creates a mock DataNode with a mock
// allocator attached and builds a new channel manager on top of that node.
func (s *ChannelManagerSuite) SetupTest() {
	node := newIDLEDataNodeMock(context.Background(), schemapb.DataType_Int64)
	node.allocator = allocator.NewMockAllocator(s.T())

	s.node = node
	s.manager = NewChannelManager(node)
}
// getWatchInfoByOpID builds a ChannelWatchInfo carrying the given operation ID
// and watch state for the named channel.
//
// The vchannel is bound to collection 1. The schema is a fixed
// "test_collection" with four fields: the row ID field, the timestamp field,
// an Int64 primary key (ID 100, "pk") and a FloatVector (ID 101, "vector")
// whose dim type param is "128".
func getWatchInfoByOpID(opID UniqueID, channel string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo {
	rowIDField := &schemapb.FieldSchema{
		FieldID:  common.RowIDField,
		Name:     common.RowIDFieldName,
		DataType: schemapb.DataType_Int64,
	}
	timestampField := &schemapb.FieldSchema{
		FieldID:  common.TimeStampField,
		Name:     common.TimeStampFieldName,
		DataType: schemapb.DataType_Int64,
	}
	pkField := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         "pk",
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}
	dimParam := &commonpb.KeyValuePair{Key: common.DimKey, Value: "128"}
	vectorField := &schemapb.FieldSchema{
		FieldID:    101,
		Name:       "vector",
		DataType:   schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{dimParam},
	}

	schema := &schemapb.CollectionSchema{
		Name:   "test_collection",
		Fields: []*schemapb.FieldSchema{rowIDField, timestampField, pkField, vectorField},
	}
	vchan := &datapb.VchannelInfo{
		CollectionID: 1,
		ChannelName:  channel,
	}

	return &datapb.ChannelWatchInfo{
		OpID:   opID,
		State:  state,
		Vchan:  vchan,
		Schema: schema,
	}
}
// TearDownTest runs after each test and closes the channel manager that
// SetupTest created.
func (s *ChannelManagerSuite) TearDownTest() {
	manager := s.manager
	manager.Close()
}
// TestWatchFail overrides WatchTimeoutInterval with "0.000001" (presumably so
// the watch operation times out; the unit is not visible here), submits a
// ToWatch request and expects WatchFailure both in the reported op state and
// in the progress returned afterwards.
func (s *ChannelManagerSuite) TestWatchFail() {
	const channel = "by-dev-rootcoord-dml-2"

	timeoutKey := Params.DataCoordCfg.WatchTimeoutInterval.Key
	paramtable.Get().Save(timeoutKey, "0.000001")
	// Restore the configured value once the test returns.
	defer paramtable.Get().Reset(timeoutKey)

	watchInfo := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	s.Require().Equal(0, s.manager.opRunners.Len())

	submitErr := s.manager.Submit(watchInfo)
	s.Require().NoError(submitErr)

	// Wait for the manager to report the outcome of the watch operation.
	reported := <-s.manager.communicateCh
	s.Require().NotNil(reported)
	s.Equal(watchInfo.GetOpID(), reported.opID)
	s.Equal(datapb.ChannelWatchState_WatchFailure, reported.state)

	s.manager.handleOpState(reported)

	progress := s.manager.GetProgress(watchInfo)
	s.Equal(datapb.ChannelWatchState_WatchFailure, progress.GetState())
}
// TestReleaseStuck watches a channel, then submits a release whose release
// function blocks. It expects the release to be reported as ReleaseFailure,
// the operation to be recorded in the manager's abnormals map under its op ID,
// and GetProgress to keep reporting ReleaseFailure after the blocked release
// function has been let go.
func (s *ChannelManagerSuite) TestReleaseStuck() {
	var (
		channel  = "by-dev-rootcoord-dml-2"
		stuckSig = make(chan struct{})
	)
	// Replace the release hook with one that blocks on an unbuffered send
	// until this test receives from stuckSig further down.
	s.manager.releaseFunc = func(channel string) {
		stuckSig <- struct{}{}
	}

	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	s.Require().Equal(0, s.manager.opRunners.Len())
	err := s.manager.Submit(info)
	s.Require().NoError(err)

	opState := <-s.manager.communicateCh
	s.Require().NotNil(opState)
	s.manager.handleOpState(opState)

	releaseInfo := getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
	paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.1")
	// Restore the configured value once the test returns.
	defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key)

	err = s.manager.Submit(releaseInfo)
	// Fatal on purpose: if the submit failed, the receive below may block
	// indefinitely instead of letting the test fail.
	s.Require().NoError(err)

	opState = <-s.manager.communicateCh
	s.Require().NotNil(opState)
	s.Equal(datapb.ChannelWatchState_ReleaseFailure, opState.state)
	s.manager.handleOpState(opState)

	s.Equal(1, s.manager.abnormals.Len())
	abchannel, ok := s.manager.abnormals.Get(releaseInfo.GetOpID())
	s.True(ok)
	s.Equal(channel, abchannel)

	// Unblock the release function installed above.
	<-stuckSig

	resp := s.manager.GetProgress(releaseInfo)
	s.Equal(datapb.ChannelWatchState_ReleaseFailure, resp.GetState())
}
// TestSubmitIdempotent submits the same ToWatch request ten times and expects
// the manager to end up with a single op runner for the channel holding a
// single unfinished operation.
func (s *ChannelManagerSuite) TestSubmitIdempotent() {
	const channel = "by-dev-rootcoord-dml-1"

	watchInfo := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	s.Require().Equal(0, s.manager.opRunners.Len())

	// Re-submit the identical request; every call must succeed.
	for remaining := 10; remaining > 0; remaining-- {
		submitErr := s.manager.Submit(watchInfo)
		s.NoError(submitErr)
	}

	s.Equal(1, s.manager.opRunners.Len())
	s.True(s.manager.opRunners.Contain(channel))

	opRunner, found := s.manager.opRunners.Get(channel)
	s.True(found)
	s.Equal(1, opRunner.UnfinishedOpSize())
}
// TestSubmitWatchAndRelease drives one channel through a successful watch and
// then a successful release.
//
// After the watch it checks that progress still reads ToWatch until the
// reported op state is handled, and that handling it registers one flowgraph
// and one op runner. After the release it checks that both are gone.
//
// Guards that precede a blocking channel receive or a pointer dereference use
// Require so that a failure stops the test instead of hanging or panicking.
func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() {
	channel := "by-dev-rootcoord-dml-0"

	info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch)
	err := s.manager.Submit(info)
	// Fatal: if the submit failed, the receive below may never return.
	s.Require().NoError(err)

	opState := <-s.manager.communicateCh
	// Fatal: opState and opState.fg are dereferenced right after these checks.
	s.Require().NotNil(opState)
	s.Equal(datapb.ChannelWatchState_WatchSuccess, opState.state)
	s.Require().NotNil(opState.fg)
	s.Equal(info.GetOpID(), opState.fg.opID)

	// The op state has not been handled yet, so progress is expected to
	// still read ToWatch.
	resp := s.manager.GetProgress(info)
	s.Equal(info.GetOpID(), resp.GetOpID())
	s.Equal(datapb.ChannelWatchState_ToWatch, resp.GetState())

	s.manager.handleOpState(opState)
	s.Equal(1, s.manager.fgManager.GetFlowgraphCount())
	s.True(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
	s.Equal(1, s.manager.opRunners.Len())

	resp = s.manager.GetProgress(info)
	s.Equal(info.GetOpID(), resp.GetOpID())
	s.Equal(datapb.ChannelWatchState_WatchSuccess, resp.GetState())

	// release
	info = getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease)
	err = s.manager.Submit(info)
	// Fatal: if the submit failed, the receive below may never return.
	s.Require().NoError(err)

	opState = <-s.manager.communicateCh
	// Fatal: opState is dereferenced on the next line.
	s.Require().NotNil(opState)
	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, opState.state)
	s.manager.handleOpState(opState)

	resp = s.manager.GetProgress(info)
	s.Equal(info.GetOpID(), resp.GetOpID())
	s.Equal(datapb.ChannelWatchState_ReleaseSuccess, resp.GetState())

	s.Equal(0, s.manager.fgManager.GetFlowgraphCount())
	s.False(s.manager.opRunners.Contain(info.GetVchan().GetChannelName()))
	s.Equal(0, s.manager.opRunners.Len())
}