enhance: support pattern matching on json field (#30779)

issue: https://github.com/milvus-io/milvus/issues/30714

---------

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2024-02-28 18:31:00 +08:00 committed by GitHub
parent ed1197ea50
commit e2f35954d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 233 additions and 9 deletions

View File

@ -9,7 +9,12 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <string>
#include <regex>
#include "common/EasyAssert.h"
namespace milvus {
std::string
@ -19,4 +24,41 @@ ReplaceUnescapedChars(const std::string& input,
// Translates a LIKE-style match pattern into a regular-expression string.
// For example "%a_b%" translates to ".*a.b.*", while a plain pattern such as
// "abc" is returned unchanged.
// NOTE(review): regex metacharacters in the pattern appear to be escaped as
// well -- confirm against the implementation.
std::string
TranslatePatternMatchToRegex(const std::string& pattern);
// Stateless functor that turns a match pattern into a regular-expression
// string. Only std::string patterns are supported; any other pattern type is
// rejected with an OpTypeInvalid panic.
//
// The call operators are const-qualified because the functor carries no state,
// which allows it to be invoked through const objects and const references.
struct PatternMatchTranslator {
    // Fallback for every pattern type without a dedicated specialization.
    // There is deliberately no return statement: PanicInfo is expected not to
    // return normally (callers observe an exception).
    template <typename T>
    inline std::string
    operator()(const T& pattern) const {
        PanicInfo(OpTypeInvalid,
                  "pattern matching is only supported on string type");
    }
};

// std::string patterns are forwarded to TranslatePatternMatchToRegex and the
// resulting regex string is returned.
template <>
inline std::string
PatternMatchTranslator::operator()<std::string>(
    const std::string& pattern) const {
    return TranslatePatternMatchToRegex(pattern);
}
// Stateless functor that reports whether an operand matches a compiled regular
// expression in its entirety (std::regex_match semantics, not a substring
// search). Only std::string and std::string_view operands can match; every
// other operand type yields false.
//
// The call operators are const-qualified because the functor carries no state,
// which allows it to be invoked through const objects and const references.
struct RegexMatcher {
    // Fallback for operand types without a dedicated specialization: such
    // values are never considered a match.
    template <typename T>
    inline bool
    operator()(const std::regex& reg, const T& operand) const {
        return false;
    }
};

// std::string operand: true only when the whole string matches `reg`.
template <>
inline bool
RegexMatcher::operator()<std::string>(const std::regex& reg,
                                      const std::string& operand) const {
    return std::regex_match(operand, reg);
}

// std::string_view operand: matched through its iterator range, so no
// temporary std::string has to be built from the view.
template <>
inline bool
RegexMatcher::operator()<std::string_view>(
    const std::regex& reg, const std::string_view& operand) const {
    return std::regex_match(operand.begin(), operand.end(), reg);
}
} // namespace milvus

View File

@ -333,6 +333,21 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
}
break;
}
case proto::plan::Match: {
PatternMatchTranslator translator;
RegexMatcher matcher;
auto regex_pattern = translator(val);
std::regex reg(regex_pattern);
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(
matcher(reg, ExprValueType(x.value())));
}
}
break;
}
default:
PanicInfo(
OpTypeInvalid,

View File

@ -38,6 +38,7 @@ GenTestSchema() {
auto schema = std::make_shared<Schema>();
schema->AddDebugField("str", DataType::VARCHAR);
schema->AddDebugField("another_str", DataType::VARCHAR);
schema->AddDebugField("json", DataType::JSON);
schema->AddDebugField(
"fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto pk = schema->AddDebugField("int64", DataType::INT64);
@ -59,6 +60,13 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
"abbb",
"abcabcabc",
};
raw_json = {
R"({"int":1})",
R"({"float":1.0})",
R"({"str":"aaa"})",
R"({"str":"bbb"})",
R"({"str":"abcabcabc"})",
};
N = 5;
uint64_t seed = 19190504;
@ -71,6 +79,16 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
for (int64_t i = 0; i < N; i++) {
str_col->at(i) = raw_str[i];
}
auto json_col = raw_data.raw_->mutable_fields_data()
->at(2)
.mutable_scalars()
->mutable_json_data()
->mutable_data();
for (int64_t i = 0; i < N; i++) {
json_col->at(i) = raw_json[i];
}
seg->PreInsert(N);
seg->Insert(0,
N,
@ -88,6 +106,7 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
SegmentGrowingPtr seg;
int64_t N;
std::vector<std::string> raw_str;
std::vector<std::string> raw_json;
};
TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnNonStringField) {
@ -141,6 +160,33 @@ TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnStringField) {
ASSERT_TRUE(final[4]);
}
// Pattern match on the nested path "str" of the JSON field in a growing
// segment. With the pattern "a%", the rows whose "str" value is "aaa" or
// "abcabcabc" are expected to match; the rows without a "str" key and the row
// holding "bbb" are expected not to.
TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnJsonField) {
    std::string like_pattern = "a%";
    const auto& json_meta = (*schema)[FieldName("json")];

    // Assemble the expression: Match(json["str"], "a%").
    auto json_column = test::GenColumnInfo(
        json_meta.get_id().get(), proto::schema::DataType::JSON, false, false);
    json_column->add_nested_path("str");
    auto match_expr = test::GenUnaryRangeExpr(OpType::Match, like_pattern);
    match_expr->set_allocated_column_info(json_column);
    auto root_expr = test::GenExpr();
    root_expr->set_allocated_unary_range_expr(match_expr);

    // Parse the proto expression and wrap it in a filter plan node.
    auto parser = ProtoParser(*schema);
    auto typed_expr = parser.ParseExprs(*root_expr);
    auto filter_node =
        std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, typed_expr);

    // Run the filter over all N rows of the growing segment.
    auto growing = dynamic_cast<SegmentGrowingImpl*>(seg.get());
    query::ExecPlanNodeVisitor visitor(*growing, MAX_TIMESTAMP);
    BitsetType hits;
    visitor.ExecuteExprNode(filter_node, growing, N, hits);

    ASSERT_FALSE(hits[0]);
    ASSERT_FALSE(hits[1]);
    ASSERT_TRUE(hits[2]);
    ASSERT_FALSE(hits[3]);
    ASSERT_TRUE(hits[4]);
}
struct MockStringIndex : index::StringIndexSort {
const bool
HasRawData() const override {
@ -166,6 +212,13 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
"abbb",
"abcabcabc",
};
raw_json = {
R"({"int":1})",
R"({"float":1.0})",
R"({"str":"aaa"})",
R"({"str":"bbb"})",
R"({"str":"abcabcabc"})",
};
N = 5;
uint64_t seed = 19190504;
auto raw_data = DataGen(schema, N, seed);
@ -180,6 +233,16 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
for (int64_t i = 0; i < N; i++) {
str_col->at(i) = raw_str[i];
}
auto json_col = raw_data.raw_->mutable_fields_data()
->at(2)
.mutable_scalars()
->mutable_json_data()
->mutable_data();
for (int64_t i = 0; i < N; i++) {
json_col->at(i) = raw_json[i];
}
SealedLoadFieldData(raw_data, *seg);
}
@ -251,6 +314,7 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
int64_t N;
std::vector<std::string> raw_str;
std::vector<int64_t> raw_int;
std::vector<std::string> raw_json;
};
TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnNonStringField) {
@ -271,9 +335,7 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnNonStringField) {
auto segpromote = dynamic_cast<SegmentSealedImpl*>(seg.get());
query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP);
BitsetType final;
ASSERT_ANY_THROW(
visitor.ExecuteExprNode(parsed, segpromote, N, final));
ASSERT_ANY_THROW(visitor.ExecuteExprNode(parsed, segpromote, N, final));
}
TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) {
@ -304,6 +366,33 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) {
ASSERT_TRUE(final[4]);
}
// Pattern match on the nested path "str" of the JSON field in a sealed
// segment. With the pattern "a%", the rows whose "str" value is "aaa" or
// "abcabcabc" are expected to match; the rows without a "str" key and the row
// holding "bbb" are expected not to.
TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnJsonField) {
    std::string like_pattern = "a%";
    const auto& json_meta = (*schema)[FieldName("json")];

    // Assemble the expression: Match(json["str"], "a%").
    auto json_column = test::GenColumnInfo(
        json_meta.get_id().get(), proto::schema::DataType::JSON, false, false);
    json_column->add_nested_path("str");
    auto match_expr = test::GenUnaryRangeExpr(OpType::Match, like_pattern);
    match_expr->set_allocated_column_info(json_column);
    auto root_expr = test::GenExpr();
    root_expr->set_allocated_unary_range_expr(match_expr);

    // Parse the proto expression and wrap it in a filter plan node.
    auto parser = ProtoParser(*schema);
    auto typed_expr = parser.ParseExprs(*root_expr);
    auto filter_node =
        std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, typed_expr);

    // Run the filter over all N rows of the sealed segment.
    auto sealed = dynamic_cast<SegmentSealedImpl*>(seg.get());
    query::ExecPlanNodeVisitor visitor(*sealed, MAX_TIMESTAMP);
    BitsetType hits;
    visitor.ExecuteExprNode(filter_node, sealed, N, hits);

    ASSERT_FALSE(hits[0]);
    ASSERT_FALSE(hits[1]);
    ASSERT_TRUE(hits[2]);
    ASSERT_FALSE(hits[3]);
    ASSERT_TRUE(hits[4]);
}
TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnIndexedNonStringField) {
int64_t operand = 120;
const auto& int_meta = schema->operator[](FieldName("another_int64"));

View File

@ -43,3 +43,67 @@ TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) {
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
EXPECT_EQ(result, "abc\\*def\\.ghi\\+");
}
// Non-string patterns (integer, floating point, boolean) must be rejected:
// every call is expected to throw.
TEST(PatternMatchTranslatorTest, InvalidTypeTest) {
    milvus::PatternMatchTranslator translate;

    ASSERT_ANY_THROW(translate(123));
    ASSERT_ANY_THROW(translate(3.14));
    ASSERT_ANY_THROW(translate(true));
}
// std::string patterns are translated: plain text is returned unchanged, while
// '%' and '_' in "%a_b%" come back as ".*" and ".".
TEST(PatternMatchTranslatorTest, StringTypeTest) {
    milvus::PatternMatchTranslator translate;

    std::string plain_abc = "abc";
    std::string plain_xyz = "xyz";
    std::string with_wildcards = "%a_b%";

    EXPECT_EQ(translate(plain_abc), "abc");
    EXPECT_EQ(translate(plain_xyz), "xyz");
    EXPECT_EQ(translate(with_wildcards), ".*a.b.*");
}
// Operands that are not strings (int, double, bool) never match, regardless of
// the regular expression.
TEST(RegexMatcherTest, DefaultBehaviorTest) {
    milvus::RegexMatcher match;
    std::regex hello_prefix("Hello.*");

    int int_operand = 123;
    double double_operand = 3.14;
    bool bool_operand = true;

    EXPECT_FALSE(match(hello_prefix, int_operand));
    EXPECT_FALSE(match(hello_prefix, double_operand));
    EXPECT_FALSE(match(hello_prefix, bool_operand));
}
// std::string operands match only when the whole string satisfies the regex
// "Hello.*".
TEST(RegexMatcherTest, StringMatchTest) {
    milvus::RegexMatcher match;
    std::regex hello_prefix("Hello.*");

    std::string greets_world = "Hello, World!";
    std::string no_hello = "Hi there!";
    std::string greets_openai = "Hello, OpenAI!";

    EXPECT_TRUE(match(hello_prefix, greets_world));
    EXPECT_FALSE(match(hello_prefix, no_hello));
    EXPECT_TRUE(match(hello_prefix, greets_openai));
}
// std::string_view operands behave like std::string ones: a match requires the
// whole view to satisfy the regex "Hello.*".
TEST(RegexMatcherTest, StringViewMatchTest) {
    milvus::RegexMatcher match;
    std::regex hello_prefix("Hello.*");

    std::string_view greets_world = "Hello, World!";
    std::string_view no_hello = "Hi there!";
    std::string_view greets_openai = "Hello, OpenAI!";

    EXPECT_TRUE(match(hello_prefix, greets_world));
    EXPECT_FALSE(match(hello_prefix, no_hello));
    EXPECT_TRUE(match(hello_prefix, greets_openai));
}

View File

@ -659,6 +659,26 @@ func (s *JSONExprSuite) checkSearch(collectionName, fieldName string, dim int) {
s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc)
log.Info("like expression run successfully")
expr = `D like "%name-%"`
checkFunc = func(result *milvuspb.SearchResults) {
s.Equal(1, len(result.Results.FieldsData))
s.Equal(fieldName, result.Results.FieldsData[0].GetFieldName())
s.Equal(schemapb.DataType_JSON, result.Results.FieldsData[0].GetType())
s.Equal(10, len(result.Results.FieldsData[0].GetScalars().GetJsonData().GetData()))
}
s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc)
log.Info("like expression run successfully")
expr = `D like "na%me"`
checkFunc = func(result *milvuspb.SearchResults) {
s.Equal(1, len(result.Results.FieldsData))
s.Equal(fieldName, result.Results.FieldsData[0].GetFieldName())
s.Equal(schemapb.DataType_JSON, result.Results.FieldsData[0].GetType())
s.Equal(0, len(result.Results.FieldsData[0].GetScalars().GetJsonData().GetData()))
}
s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc)
log.Info("like expression run successfully")
expr = `A in []`
checkFunc = func(result *milvuspb.SearchResults) {
for _, topk := range result.GetResults().GetTopks() {
@ -700,12 +720,6 @@ func (s *JSONExprSuite) checkSearch(collectionName, fieldName string, dim int) {
expr = `A like abc`
s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim)
expr = `D like "%name-%"`
s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim)
expr = `D like "na%me"`
s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim)
expr = `1+5 <= A+1 < 5+10`
s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim)