diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index fda5c981e7..5adccad6e8 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -131,7 +131,7 @@ func (v *ParserVisitor) VisitFloating(ctx *parser.FloatingContext) interface{} { // VisitString translates expr to GenericValue. func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} { - pattern, err := convertEscapeSingle(ctx.StringLiteral().GetText()) + pattern, err := convertEscapeSingle(ctx.GetText()) if err != nil { return err } diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index 6916a70220..af3a1c9cac 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -2,7 +2,9 @@ package planparserv2 import ( "fmt" + "strings" "time" + "unicode" "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/hashicorp/golang-lru/v2/expirable" @@ -126,7 +128,39 @@ func CreateRetrievePlan(schema *typeutil.SchemaHelper, exprStr string) (*planpb. 
// convertHanToASCII rewrites every Han (CJK ideograph) rune in s as an ASCII
// Unicode escape sequence produced by formatUnicode, leaving all other runes
// untouched. CreateSearchPlan applies it to the filter expression before the
// expression is parsed.
//
// Backslash handling:
//   - A backslash followed by a byte accepted by isEscapeCh is copied through
//     together with the rune that follows it; that following rune is never
//     converted.
//   - A backslash followed by any other byte aborts the conversion and the
//     original string s is returned unchanged.
//   - A backslash that is the last byte of s is copied through as-is.
func convertHanToASCII(s string) string {
	var builder strings.Builder
	// Worst-case growth: a 3-byte UTF-8 Han rune becomes a 6-byte "\uXXXX"
	// escape and a 4-byte one becomes a 10-byte "\UXXXXXXXX" escape, so three
	// times the input length is always enough.
	builder.Grow(len(s) * 3)

	n := len(s)
	afterBackslash := false
	for i, r := range s {
		// The rune right after an accepted backslash is part of an escape
		// sequence and must be copied verbatim.
		if afterBackslash {
			builder.WriteRune(r)
			afterBackslash = false
			continue
		}

		if r == '\\' {
			// Unknown escape: give up and hand the untouched input back.
			if i+1 < n && !isEscapeCh(s[i+1]) {
				return s
			}
			afterBackslash = true
			builder.WriteRune(r)
			continue
		}

		if unicode.Is(unicode.Han, r) {
			builder.WriteString(formatUnicode(uint32(r)))
			continue
		}
		builder.WriteRune(r)
	}

	return builder.String()
}

// isEscapeCh reports whether ch is one of the bytes convertHanToASCII accepts
// directly after a backslash: \ n t r f " and '.
func isEscapeCh(ch uint8) bool {
	switch ch {
	case '\\', 'n', 't', 'r', 'f', '"', '\'':
		return true
	}
	return false
}

// formatUnicode renders the code point r as a lowercase hexadecimal escape
// sequence: `\uXXXX` (four digits) for values up to U+FFFF and `\UXXXXXXXX`
// (eight digits) for anything larger.
//
// The eight-digit form matters because unicode.Han also contains ideographs
// above U+FFFF (for example U+20000); squeezing those into four digits would
// drop the high bits and silently name a different character.
//
// NOTE(review): the `\U` form follows Go's escape syntax, which
// strconv.Unquote decodes; confirm the expression lexer accepts it as well.
func formatUnicode(r uint32) string {
	if r > 0xffff {
		return string([]byte{
			'\\', 'U',
			hexDigit(r >> 28),
			hexDigit(r >> 24),
			hexDigit(r >> 20),
			hexDigit(r >> 16),
			hexDigit(r >> 12),
			hexDigit(r >> 8),
			hexDigit(r >> 4),
			hexDigit(r),
		})
	}
	return string([]byte{
		'\\', 'u',
		hexDigit(r >> 12),
		hexDigit(r >> 8),
		hexDigit(r >> 4),
		hexDigit(r),
	})
}

// hexDigit returns the lowercase hexadecimal character for the low four bits
// of n; higher bits are ignored.
func hexDigit(n uint32) byte {
	const digits = "0123456789abcdef"
	return digits[n&0xf]
}