fix: memory leak while parsing query plan (#34931)

issue: #34930

Signed-off-by: jaime <yun.zhang@zilliz.com>
This commit is contained in:
jaime 2024-07-28 21:50:20 +08:00 committed by GitHub
parent 9463eeef2b
commit 08fa51d4f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 15 deletions

View File

@ -7,11 +7,20 @@ import (
"github.com/antlr/antlr4/runtime/Go/antlr" "github.com/antlr/antlr4/runtime/Go/antlr"
) )
type errorListener struct { type errorListener interface {
antlr.ErrorListener
Error() error
}
type errorListenerImpl struct {
*antlr.DefaultErrorListener *antlr.DefaultErrorListener
err error err error
} }
func (l *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { func (l *errorListenerImpl) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
l.err = fmt.Errorf("line " + strconv.Itoa(line) + ":" + strconv.Itoa(column) + " " + msg) l.err = fmt.Errorf("line " + strconv.Itoa(line) + ":" + strconv.Itoa(column) + " " + msg)
} }
func (l *errorListenerImpl) Error() error {
return l.err
}

View File

@ -14,6 +14,10 @@ import (
) )
func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} {
return handleExprWithErrorListener(schema, exprStr, &errorListenerImpl{})
}
func handleExprWithErrorListener(schema *typeutil.SchemaHelper, exprStr string, errorListener errorListener) interface{} {
if isEmptyExpression(exprStr) { if isEmptyExpression(exprStr) {
return &ExprWithType{ return &ExprWithType{
dataType: schemapb.DataType_Bool, dataType: schemapb.DataType_Bool,
@ -22,21 +26,19 @@ func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} {
} }
inputStream := antlr.NewInputStream(exprStr) inputStream := antlr.NewInputStream(exprStr)
errorListener := &errorListener{}
lexer := getLexer(inputStream, errorListener) lexer := getLexer(inputStream, errorListener)
if errorListener.err != nil { if errorListener.Error() != nil {
return errorListener.err return errorListener.Error()
} }
parser := getParser(lexer, errorListener) parser := getParser(lexer, errorListener)
if errorListener.err != nil { if errorListener.Error() != nil {
return errorListener.err return errorListener.Error()
} }
ast := parser.Expr() ast := parser.Expr()
if errorListener.err != nil { if errorListener.Error() != nil {
return errorListener.err return errorListener.Error()
} }
if parser.GetCurrentToken().GetTokenType() != antlr.TokenEOF { if parser.GetCurrentToken().GetTokenType() != antlr.TokenEOF {

View File

@ -4,6 +4,7 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -604,6 +605,39 @@ func TestCreateSearchPlan_Invalid(t *testing.T) {
}) })
} }
var listenerCnt int
type errorListenerTest struct {
antlr.DefaultErrorListener
}
func (l *errorListenerTest) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
listenerCnt += 1
}
func (l *errorListenerTest) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
listenerCnt += 1
}
func (l *errorListenerTest) Error() error {
return nil
}
func Test_FixErrorListenerNotRemoved(t *testing.T) {
schema := newTestSchema()
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
normal := "1 < Int32Field < (Int16Field)"
for i := 0; i < 10; i++ {
err := handleExprWithErrorListener(schemaHelper, normal, &errorListenerTest{})
err1, ok := err.(error)
assert.True(t, ok)
assert.Error(t, err1)
}
assert.True(t, listenerCnt <= 10)
}
func Test_handleExpr(t *testing.T) { func Test_handleExpr(t *testing.T) {
schema := newTestSchema() schema := newTestSchema()
schemaHelper, err := typeutil.CreateSchemaHelper(schema) schemaHelper, err := typeutil.CreateSchemaHelper(schema)

View File

@ -72,11 +72,13 @@ func getParser(lexer *antlrparser.PlanLexer, listeners ...antlr.ErrorListener) *
func putLexer(lexer *antlrparser.PlanLexer) { func putLexer(lexer *antlrparser.PlanLexer) {
lexer.SetInputStream(nil) lexer.SetInputStream(nil)
lexer.RemoveErrorListeners()
lexerPool.ReturnObject(context.TODO(), lexer) lexerPool.ReturnObject(context.TODO(), lexer)
} }
func putParser(parser *antlrparser.PlanParser) { func putParser(parser *antlrparser.PlanParser) {
parser.SetInputStream(nil) parser.SetInputStream(nil)
parser.RemoveErrorListeners()
parserPool.ReturnObject(context.TODO(), parser) parserPool.ReturnObject(context.TODO(), parser)
} }

View File

@ -16,10 +16,10 @@ func genNaiveInputStream() *antlr.InputStream {
func Test_getLexer(t *testing.T) { func Test_getLexer(t *testing.T) {
var lexer *antlrparser.PlanLexer var lexer *antlrparser.PlanLexer
resetLexerPool() resetLexerPool()
lexer = getLexer(genNaiveInputStream(), &errorListener{}) lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer) assert.NotNil(t, lexer)
lexer = getLexer(genNaiveInputStream(), &errorListener{}) lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer) assert.NotNil(t, lexer)
pool := getLexerPool() pool := getLexerPool()
@ -36,13 +36,13 @@ func Test_getParser(t *testing.T) {
var parser *antlrparser.PlanParser var parser *antlrparser.PlanParser
resetParserPool() resetParserPool()
lexer = getLexer(genNaiveInputStream(), &errorListener{}) lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer) assert.NotNil(t, lexer)
parser = getParser(lexer, &errorListener{}) parser = getParser(lexer, &errorListenerImpl{})
assert.NotNil(t, parser) assert.NotNil(t, parser)
parser = getParser(lexer, &errorListener{}) parser = getParser(lexer, &errorListenerImpl{})
assert.NotNil(t, parser) assert.NotNil(t, parser)
pool := getParserPool() pool := getParserPool()