mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-03 12:29:36 +08:00
enhance: Convert unincode to ascii to improving expression parsing efficiency (#36675)
issue: #36672 --------- Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
This commit is contained in:
parent
2ec6e602d6
commit
fc8b5ab791
@ -131,7 +131,7 @@ func (v *ParserVisitor) VisitFloating(ctx *parser.FloatingContext) interface{} {
|
|||||||
|
|
||||||
// VisitString translates expr to GenericValue.
|
// VisitString translates expr to GenericValue.
|
||||||
func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} {
|
func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} {
|
||||||
pattern, err := convertEscapeSingle(ctx.StringLiteral().GetText())
|
pattern, err := convertEscapeSingle(ctx.GetText())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,9 @@ package planparserv2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||||
@ -126,7 +128,39 @@ func CreateRetrievePlan(schema *typeutil.SchemaHelper, exprStr string) (*planpb.
|
|||||||
return planNode, nil
|
return planNode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertHanToASCII(s string) string {
|
||||||
|
var builder strings.Builder
|
||||||
|
builder.Grow(len(s) * 6)
|
||||||
|
skipCur := false
|
||||||
|
n := len(s)
|
||||||
|
for i, r := range s {
|
||||||
|
if skipCur {
|
||||||
|
builder.WriteRune(r)
|
||||||
|
skipCur = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if r == '\\' {
|
||||||
|
if i+1 < n && !isEscapeCh(s[i+1]) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
skipCur = true
|
||||||
|
builder.WriteRune(r)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if unicode.Is(unicode.Han, r) {
|
||||||
|
builder.WriteString(formatUnicode(uint32(r)))
|
||||||
|
} else {
|
||||||
|
builder.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) (*planpb.PlanNode, error) {
|
func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) (*planpb.PlanNode, error) {
|
||||||
|
exprStr = convertHanToASCII(exprStr)
|
||||||
|
|
||||||
parse := func() (*planpb.Expr, error) {
|
parse := func() (*planpb.Expr, error) {
|
||||||
if len(exprStr) <= 0 {
|
if len(exprStr) <= 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -2,6 +2,7 @@ package planparserv2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -1053,13 +1054,14 @@ c'`,
|
|||||||
`A == "\中国"`,
|
`A == "\中国"`,
|
||||||
}
|
}
|
||||||
for _, expr = range invalidExprs {
|
for _, expr = range invalidExprs {
|
||||||
_, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{
|
plan, err := CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{
|
||||||
Topk: 0,
|
Topk: 0,
|
||||||
MetricType: "",
|
MetricType: "",
|
||||||
SearchParams: "",
|
SearchParams: "",
|
||||||
RoundDecimal: 0,
|
RoundDecimal: 0,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
fmt.Println(plan)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1338,3 +1340,54 @@ func BenchmarkNoPlanCache(b *testing.B) {
|
|||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomChineseString(length int) string {
|
||||||
|
min := 0x4e00
|
||||||
|
max := 0x9fa5
|
||||||
|
|
||||||
|
result := make([]rune, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
result[i] = rune(rand.Intn(max-min+1) + min)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWithString(b *testing.B) {
|
||||||
|
schema := newTestSchema()
|
||||||
|
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
expr := ""
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
expr += fmt.Sprintf(`"%s",`, randomChineseString(rand.Intn(100)))
|
||||||
|
}
|
||||||
|
expr = "StringField in [" + expr + "]"
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
plan, err := CreateSearchPlan(schemaHelper, expr, "FloatVectorField", &planpb.QueryInfo{
|
||||||
|
Topk: 0,
|
||||||
|
MetricType: "",
|
||||||
|
SearchParams: "",
|
||||||
|
RoundDecimal: 0,
|
||||||
|
})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
assert.NotNil(b, plan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_convertHanToASCII(t *testing.T) {
|
||||||
|
type testcase struct {
|
||||||
|
source string
|
||||||
|
target string
|
||||||
|
}
|
||||||
|
testcases := []testcase{
|
||||||
|
{`A in ["中国"]`, `A in ["\u4e2d\u56fd"]`},
|
||||||
|
{`A in ["\中国"]`, `A in ["\中国"]`},
|
||||||
|
{`A in ["\\中国"]`, `A in ["\\\u4e2d\u56fd"]`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range testcases {
|
||||||
|
assert.Equal(t, c.target, convertHanToASCII(c.source))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -539,3 +539,25 @@ func isIntegerColumn(col *planpb.ColumnInfo) bool {
|
|||||||
(typeutil.IsArrayType(col.GetDataType()) && typeutil.IsIntegerType(col.GetElementType())) ||
|
(typeutil.IsArrayType(col.GetDataType()) && typeutil.IsIntegerType(col.GetElementType())) ||
|
||||||
typeutil.IsJSONType(col.GetDataType())
|
typeutil.IsJSONType(col.GetDataType())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isEscapeCh(ch uint8) bool {
|
||||||
|
return ch == '\\' || ch == 'n' || ch == 't' || ch == 'r' || ch == 'f' || ch == '"' || ch == '\''
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatUnicode(r uint32) string {
|
||||||
|
return string([]byte{
|
||||||
|
'\\', 'u',
|
||||||
|
hexDigit(r >> 12),
|
||||||
|
hexDigit(r >> 8),
|
||||||
|
hexDigit(r >> 4),
|
||||||
|
hexDigit(r),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func hexDigit(n uint32) byte {
|
||||||
|
n &= 0xf
|
||||||
|
if n < 10 {
|
||||||
|
return byte(n) + '0'
|
||||||
|
}
|
||||||
|
return byte(n-10) + 'a'
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user