diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index 1fbf14ffda..e298bcaa53 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -12,68 +12,82 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + planparserv2 "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// exprParseKey is used to cache the parse result. Currently only collectionName is used besides expr string, which implies -// that the same collectionName will have the same schema thus the same parse result. In the future, if there is case that the -// schema changes without changing the collectionName, we need to change the cache key. -type exprParseKey struct { - collectionName string - expr string -} - -var exprCache = expirable.NewLRU[exprParseKey, any](256, nil, time.Minute*10) - -func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { - parseKey := exprParseKey{collectionName: schema.GetCollectionName(), expr: exprStr} - val, ok := exprCache.Get(parseKey) - if !ok { - exprStr = convertHanToASCII(exprStr) - val = handleExprWithErrorListener(schema, exprStr, &errorListenerImpl{}) - // Note that the errors will be cached, too. - exprCache.Add(parseKey, val) +var ( + exprCache = expirable.NewLRU[string, any](1024, nil, time.Minute*10) + trueLiteral = &ExprWithType{ + dataType: schemapb.DataType_Bool, + expr: alwaysTrueExpr(), } +) - return val -} - -func handleExprWithErrorListener(schema *typeutil.SchemaHelper, exprStr string, errorListener errorListener) interface{} { - if isEmptyExpression(exprStr) { - return &ExprWithType{ - dataType: schemapb.DataType_Bool, - expr: alwaysTrueExpr(), +func handleInternal(exprStr string) (ast planparserv2.IExprContext, err error) { + val, ok := exprCache.Get(exprStr) + if ok { + switch v := val.(type) { + case planparserv2.IExprContext: + return v, nil + case error: + return nil, v + default: + return nil, fmt.Errorf("unknown cache error: %v", v) } } - inputStream := antlr.NewInputStream(exprStr) - lexer := getLexer(inputStream, errorListener) - if errorListener.Error() != nil { - return errorListener.Error() + // Note that the errors will be cached, too. + defer func() { + if err != nil { + exprCache.Add(exprStr, err) + } + }() + exprNormal := convertHanToASCII(exprStr) + listener := &errorListenerImpl{} + + inputStream := antlr.NewInputStream(exprNormal) + lexer := getLexer(inputStream, listener) + if err = listener.Error(); err != nil { + return } - parser := getParser(lexer, errorListener) - if errorListener.Error() != nil { - return errorListener.Error() + parser := getParser(lexer, listener) + if err = listener.Error(); err != nil { + return } - ast := parser.Expr() - if errorListener.Error() != nil { - return errorListener.Error() + ast = parser.Expr() + if err = listener.Error(); err != nil { + return } if parser.GetCurrentToken().GetTokenType() != antlr.TokenEOF { log.Info("invalid expression", zap.String("expr", exprStr)) - return fmt.Errorf("invalid expression: %s", exprStr) + err = fmt.Errorf("invalid expression: %s", exprStr) + return } // lexer & parser won't be used by this thread, can be put into pool. putLexer(lexer) putParser(parser) + exprCache.Add(exprStr, ast) + return +} + +func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { + if isEmptyExpression(exprStr) { + return trueLiteral + } + ast, err := handleInternal(exprStr) + if err != nil { + return err + } + visitor := NewParserVisitor(schema) return ast.Accept(visitor) } diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index febff4be83..69b17398d2 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -719,7 +719,7 @@ func Test_FixErrorListenerNotRemoved(t *testing.T) { normal := "1 < Int32Field < (Int16Field)" for i := 0; i < 10; i++ { - err := handleExprWithErrorListener(schemaHelper, normal, &errorListenerTest{}) + err := handleExpr(schemaHelper, normal) err1, ok := err.(error) assert.True(t, ok) assert.Error(t, err1) @@ -1379,25 +1379,21 @@ func BenchmarkPlanCache(b *testing.B) { b.ResetTimer() - for i := 0; i < b.N; i++ { - r := handleExpr(schemaHelper, "array_length(ArrayField) == 10") - err := getError(r) - assert.NoError(b, err) - } -} + b.Run("cached", func(b *testing.B) { + for i := 0; i < b.N; i++ { + r := handleExpr(schemaHelper, "array_length(ArrayField) == 10") + err := getError(r) + assert.NoError(b, err) + } + }) -func BenchmarkNoPlanCache(b *testing.B) { - schema := newTestSchema() - schemaHelper, err := typeutil.CreateSchemaHelper(schema) - require.NoError(b, err) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r := handleExpr(schemaHelper, fmt.Sprintf("array_length(ArrayField) == %d", i)) - err := getError(r) - assert.NoError(b, err) - } + b.Run("uncached", func(b *testing.B) { + for i := 0; i < b.N; i++ { + r := handleExpr(schemaHelper, fmt.Sprintf("array_length(ArrayField) == %d", i)) + err := getError(r) + assert.NoError(b, err) + } + }) } func randomChineseString(length int) string {