milvus/internal/proxy/accesslog/info/restful_info_test.go
2024-06-05 17:13:50 +08:00

206 lines
5.5 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package info
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// RestfulAccessInfoSuite exercises the access-log field getters of
// RestfulInfo. SetupTest rebuilds the fixture before every test method.
type RestfulAccessInfoSuite struct {
	suite.Suite

	// username and traceID are filled with placeholder values by SetupTest;
	// none of the test methods in this suite currently read them.
	username, traceID string

	// info is the RestfulInfo under test, recreated before each test.
	info *RestfulInfo
}
// SetupSuite runs once before any test in the suite and initializes the
// paramtable package that the code under test depends on.
func (s *RestfulAccessInfoSuite) SetupSuite() {
	paramtable.Init()
}
// SetupTest resets the fixture before each test. It installs placeholder
// identity values and a fresh RestfulInfo wrapping empty gin log-formatter
// params. Keys is a non-nil map, so tests can write into it directly.
func (s *RestfulAccessInfoSuite) SetupTest() {
	s.username, s.traceID = "test-user", "test-trace"

	params := &gin.LogFormatterParams{Keys: map[string]any{}}
	s.info = &RestfulInfo{}
	s.info.SetParams(params)
}
// TestTimeCost checks that $time_cost renders the recorded latency using the
// default formatting of time.Duration.
func (s *RestfulAccessInfoSuite) TestTimeCost() {
	const latency = time.Second
	s.info.params.Latency = latency

	got := Get(s.info, "$time_cost")
	s.Equal(fmt.Sprint(latency), got[0])
}
// TestTimeNow checks that $time_now produces a concrete value (not Unknown),
// even on a freshly built RestfulInfo.
func (s *RestfulAccessInfoSuite) TestTimeNow() {
	now := Get(s.info, "$time_now")[0]
	s.NotEqual(Unknown, now)
}
// TestTimeStart checks that $time_start is Unknown until a start time is
// recorded. After that, it is the start time formatted with timeFormat.
func (s *RestfulAccessInfoSuite) TestTimeStart() {
	s.Equal(Unknown, Get(s.info, "$time_start")[0])

	start := time.Now()
	s.info.start = start
	s.Equal(start.Format(timeFormat), Get(s.info, "$time_start")[0])
}
// TestTimeEnd checks that $time_end is the gin params' TimeStamp formatted
// with timeFormat.
func (s *RestfulAccessInfoSuite) TestTimeEnd() {
	end := time.Now()
	s.info.params.TimeStamp = end

	s.Equal(end.Format(timeFormat), Get(s.info, "$time_end")[0])
}
// TestMethodName checks that, for RESTful requests, $method_name reports the
// request path.
func (s *RestfulAccessInfoSuite) TestMethodName() {
	const path = "/restful/test"
	s.info.params.Path = path

	s.Equal(path, Get(s.info, "$method_name")[0])
}
// TestAddress checks that $user_addr reports the client IP from the gin
// params.
func (s *RestfulAccessInfoSuite) TestAddress() {
	const clientIP = "127.0.0.1"
	s.info.params.ClientIP = clientIP

	s.Equal(clientIP, Get(s.info, "$user_addr")[0])
}
// TestTraceID checks that $trace_id is Unknown while the "traceID" key is
// absent from params.Keys. Once the key is set, it reports the stored value.
func (s *RestfulAccessInfoSuite) TestTraceID() {
	s.Equal(Unknown, Get(s.info, "$trace_id")[0])

	const traceID = "testtrace"
	s.info.params.Keys["traceID"] = traceID
	s.Equal(traceID, Get(s.info, "$trace_id")[0])
}
// TestStatus checks the $method_status rendering. An HTTP 400 is reported as
// "HttpError400". With HTTP 200, the Milvus return code stored under
// ContextReturnCode decides between "Failed" and "Successful".
func (s *RestfulAccessInfoSuite) TestStatus() {
	s.info.params.StatusCode = http.StatusBadRequest
	s.Equal("HttpError400", Get(s.info, "$method_status")[0])

	// The cases run in order against the same Keys map; each one overwrites
	// the return code left by the previous case.
	for _, tc := range []struct {
		err  error
		want string
	}{
		{merr.ErrChannelLack, "Failed"},
		{nil, "Successful"},
	} {
		s.info.params.StatusCode = http.StatusOK
		s.info.params.Keys[ContextReturnCode] = merr.Code(tc.err)
		s.Equal(tc.want, Get(s.info, "$method_status")[0])
	}
}
// TestErrorCode checks that $error_code is Unknown until a return code is
// stored under ContextReturnCode. After that, it is that code rendered with
// fmt.Sprint.
func (s *RestfulAccessInfoSuite) TestErrorCode() {
	s.Equal(Unknown, Get(s.info, "$error_code")[0])

	code := 200
	s.info.params.Keys[ContextReturnCode] = code
	s.Equal(fmt.Sprint(code), Get(s.info, "$error_code")[0])
}
// TestErrorMsg checks that $error_msg reports the message stored under
// ContextReturnMessage.
func (s *RestfulAccessInfoSuite) TestErrorMsg() {
	msg := merr.ErrChannelLack.Error()
	s.info.params.Keys[ContextReturnMessage] = msg

	s.Equal(msg, Get(s.info, "$error_msg")[0])
}
// TestDbName checks that $database_name is Unknown without a request. Once a
// request is attached, it reports the request's DbName.
func (s *RestfulAccessInfoSuite) TestDbName() {
	s.Equal(Unknown, Get(s.info, "$database_name")[0])

	s.info.req = &milvuspb.QueryRequest{DbName: "test"}
	s.Equal("test", Get(s.info, "$database_name")[0])
}
// TestSdkInfo checks that, on a bare RestfulInfo, $sdk_version reports the
// fixed "Restful" label.
func (s *RestfulAccessInfoSuite) TestSdkInfo() {
	sdk := Get(s.info, "$sdk_version")[0]
	s.Equal("Restful", sdk)
}
// TestExpression checks that $method_expr is Unknown without a request. It
// also checks that the expression is extracted from both a QueryRequest
// (Expr) and a SearchRequest (Dsl).
func (s *RestfulAccessInfoSuite) TestExpression() {
	s.Equal(Unknown, Get(s.info, "$method_expr")[0])

	const expr = "test"
	expectExpr := func() {
		s.Equal(expr, Get(s.info, "$method_expr")[0])
	}

	s.info.req = &milvuspb.QueryRequest{Expr: expr}
	expectExpr()

	s.info.req = &milvuspb.SearchRequest{Dsl: expr}
	expectExpr()
}
// TestOutputFields checks that $output_fields is Unknown without a request.
// After a QueryRequest is placed under ContextRequest and InitReq runs, it
// reports the request's output fields rendered with fmt.Sprint.
func (s *RestfulAccessInfoSuite) TestOutputFields() {
	s.Equal(Unknown, Get(s.info, "$output_fields")[0])

	fields := []string{"pk"}
	s.info.params.Keys[ContextRequest] = &milvuspb.QueryRequest{OutputFields: fields}
	// Presumably InitReq loads the request from params.Keys; the assertion
	// below relies on that.
	s.info.InitReq()

	s.Equal(fmt.Sprint(fields), Get(s.info, "$output_fields")[0])
}
// TestConsistencyLevel checks that $consistency_level is Unknown without a
// request. After a QueryRequest is placed under ContextRequest and InitReq
// runs, it reports the request's consistency level name.
func (s *RestfulAccessInfoSuite) TestConsistencyLevel() {
	s.Equal(Unknown, Get(s.info, "$consistency_level")[0])

	level := commonpb.ConsistencyLevel_Bounded
	s.info.params.Keys[ContextRequest] = &milvuspb.QueryRequest{ConsistencyLevel: level}
	s.info.InitReq()

	s.Equal(level.String(), Get(s.info, "$consistency_level")[0])
}
// TestClusterPrefix checks that $cluster_prefix reports the value stored in
// the package-level ClusterPrefix.
//
// ClusterPrefix is shared package state, so the previous value is restored
// when the test returns. Otherwise the override would leak into any test
// that runs later.
func (s *RestfulAccessInfoSuite) TestClusterPrefix() {
	const cluster = "instance-test"

	// paramtable.Init already ran in SetupSuite; calling it again here was
	// redundant.
	prev := ClusterPrefix.Load()
	defer ClusterPrefix.Store(prev)

	ClusterPrefix.Store(cluster)
	s.Equal(cluster, Get(s.info, "$cluster_prefix")[0])
}
// TestRestfulAccessInfo is the `go test` entry point. It runs every Test*
// method of RestfulAccessInfoSuite.
func TestRestfulAccessInfo(t *testing.T) {
	suite.Run(t, &RestfulAccessInfoSuite{})
}