From 1aa97a5c21cb2644493d86a34506aedff7d15c0e Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Fri, 1 Mar 2024 16:57:00 +0800 Subject: [PATCH] enhance: Support more relational operators for binary expressions (#30902) issue: #30677 Signed-off-by: Cai Zhang --- .../expression/BinaryArithOpEvalRangeExpr.cpp | 801 ++++++++- .../expression/BinaryArithOpEvalRangeExpr.h | 180 ++ internal/core/unittest/test_array_expr.cpp | 388 +++++ internal/core/unittest/test_expr.cpp | 1512 +++++++++++++++-- .../planparserv2/plan_parser_v2_test.go | 21 +- internal/parser/planparserv2/utils.go | 8 - .../integration/expression/expression_test.go | 240 +++ tests/python_client/testcases/test_search.py | 6 +- 8 files changed, 3037 insertions(+), 119 deletions(-) create mode 100644 tests/integration/expression/expression_test.go diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index d46cdc7558..2ce0d8b2a9 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -119,7 +119,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { auto op_type = expr_->op_type_; auto arith_type = expr_->arith_op_type_; auto value = GetValueFromProto(expr_->value_); - auto right_operand = GetValueFromProto(expr_->right_operand_); + auto right_operand = + arith_type != proto::plan::ArithOpType::ArrayLength + ? GetValueFromProto(expr_->right_operand_) + : ValueType(); #define BinaryArithRangeJSONCompare(cmp) \ do { \ @@ -260,6 +263,202 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } break; } + case proto::plan::OpType::GreaterThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompare(x.value() + right_operand > + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompare(x.value() - right_operand > + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompare(x.value() * right_operand > + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompare(x.value() / right_operand > + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) > val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length > val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::GreaterEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompare(x.value() + right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompare(x.value() - right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompare(x.value() * right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompare(x.value() / right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) >= val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length >= val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompare(x.value() + right_operand < + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompare(x.value() - right_operand < + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompare(x.value() * right_operand < + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompare(x.value() / right_operand < + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) < val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length < val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompare(x.value() + right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompare(x.value() - right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompare(x.value() * right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompare(x.value() / right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) <= val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length <= val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } default: PanicInfo(OpTypeInvalid, "unsupported operator type for binary " @@ -413,6 +612,178 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } break; } + case proto::plan::OpType::GreaterThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand > + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand > + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand > + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand > + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast(fmod(value, right_operand)) > + val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() > val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::GreaterEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand >= + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast( + fmod(value, right_operand)) >= val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() >= val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand < + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand < + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand < + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand < + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast(fmod(value, right_operand)) < + val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() < val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand <= + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast( + fmod(value, right_operand)) <= val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() <= val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } default: PanicInfo(OpTypeInvalid, "unsupported operator type for binary " @@ -579,6 +950,230 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() { } break; } + case proto::plan::OpType::GreaterThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::GreaterEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } default: PanicInfo(OpTypeInvalid, "unsupported operator type for binary " @@ -727,6 +1322,210 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { } break; } + case proto::plan::OpType::GreaterThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::GreaterEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessThan: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::LessEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } default: PanicInfo(OpTypeInvalid, "unsupported operator type for binary " diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index c16c0d983a..805d77c62d 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -88,6 +88,94 @@ struct ArithOpElementFunc { "unsupported arith type:{} for ArithOpElementFunc", arith_op)); } + } else if constexpr (cmp_op == proto::plan::OpType::GreaterThan) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) > val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::GreaterEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) >= val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::LessThan) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) < val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::LessEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) <= val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } } } } @@ -157,6 +245,98 @@ struct ArithOpIndexFunc { "unsupported arith type:{} for ArithOpElementFunc", arith_op)); } + } else if constexpr (cmp_op == proto::plan::OpType::GreaterThan) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) > val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) > val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::GreaterEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) >= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) >= val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::LessThan) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) < val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) < val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::LessEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) <= val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) <= val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } } } return res_vec; diff --git a/internal/core/unittest/test_array_expr.cpp b/internal/core/unittest/test_array_expr.cpp index 1a1daa9cbb..06266f6e4a 100644 --- a/internal/core/unittest/test_array_expr.cpp +++ b/internal/core/unittest/test_array_expr.cpp @@ -1274,6 +1274,74 @@ TEST(Expr, TestArrayBinaryArith) { auto val = array.get_data(0); return val + 2 != 5; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:GreaterThan + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 > 5; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:GreaterEqual + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 >= 5; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:LessThan + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 < 5; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:LessEqual + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 <= 5; + }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 @@ -1308,6 +1376,74 @@ TEST(Expr, TestArrayBinaryArith) { auto val = array.get_data(0); return val - 1 != 144; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:GreaterThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 > 144; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:GreaterEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 >= 144; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:LessThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 < 144; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:LessEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 <= 144; + }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 @@ -1410,6 +1546,74 @@ TEST(Expr, TestArrayBinaryArith) { auto val = array.get_data(0); return val * 2 != 20; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:GreaterThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 > 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:GreaterEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 >= 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:LessThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 < 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:LessEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 <= 20; + }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 @@ -1444,6 +1648,74 @@ TEST(Expr, TestArrayBinaryArith) { auto val = array.get_data(0); return val / 2 != 20; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:GreaterThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 > 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:GreaterEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 >= 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:LessThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 < 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:LessEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 <= 20; + }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 @@ -1478,6 +1750,74 @@ TEST(Expr, TestArrayBinaryArith) { auto val = array.get_data(0); return val % 3 != 2; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:GreaterThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:GreaterEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:LessThan + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:LessEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 <= 2; + }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 @@ -1704,6 +2044,54 @@ TEST(Expr, TestArrayBinaryArith) { >)", "int", [](milvus::Array& array) { return array.length() != 8; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + element_type:Int8 + > + arith_op:ArrayLength + op:GreaterThan + value: + >)", + "int", + [](milvus::Array& array) { return array.length() > 8; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + element_type:Int8 + > + arith_op:ArrayLength + op:GreaterEqual + value: + >)", + "int", + [](milvus::Array& array) { return array.length() >= 8; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + element_type:Int8 + > + arith_op:ArrayLength + op:LessThan + value: + >)", + "int", + [](milvus::Array& array) { return array.length() < 8; }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + element_type:Int8 + > + arith_op:ArrayLength + op:LessEqual + value: + >)", + "int", + [](milvus::Array& array) { return array.length() <= 8; }}, }; std::string raw_plan_tmp = R"(vector_anns: < diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index df109b7d49..dea2e889d6 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -2679,10 +2679,9 @@ TEST(Expr, TestCompareWithScalarIndexMaris) { } TEST(Expr, TestBinaryArithOpEvalRange) { - std::vector, DataType>> - testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(binary_arith_op_eval_range_expr: < + std::vector, DataType>> testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Int8 @@ -2696,9 +2695,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 8 > >)", - [](int8_t v) { return (v + 4) == 8; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](int8_t v) { return (v + 4) == 8; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Int16 @@ -2712,9 +2711,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 1500 > >)", - [](int16_t v) { return (v - 500) == 1500; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](int16_t v) { return (v - 500) == 1500; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Int32 @@ -2728,9 +2727,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 4000 > >)", - [](int32_t v) { return (v * 2) == 4000; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](int32_t v) { return (v * 2) == 4000; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Int64 @@ -2744,9 +2743,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 1000 > >)", - [](int64_t v) { return (v / 2) == 1000; }, - DataType::INT64}, - {R"(binary_arith_op_eval_range_expr: < + [](int64_t v) { return (v / 2) == 1000; }, + DataType::INT64}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Int32 @@ -2760,9 +2759,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 0 > >)", - [](int32_t v) { return (v % 100) == 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](int32_t v) { return (v % 100) == 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 105 data_type: Float @@ -2776,9 +2775,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { float_val: 2500 > >)", - [](float v) { return (v + 500) == 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](float v) { return (v + 500) == 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 106 data_type: Double @@ -2792,10 +2791,10 @@ TEST(Expr, TestBinaryArithOpEvalRange) { float_val: 2500 > >)", - [](double v) { return (v + 500) == 2500; }, - DataType::DOUBLE}, - // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](double v) { return (v + 500) == 2500; }, + DataType::DOUBLE}, + // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 105 data_type: Float @@ -2809,9 +2808,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { float_val: 2500 > >)", - [](float v) { return (v + 500) != 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](float v) { return (v + 500) != 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 106 data_type: Double @@ -2825,9 +2824,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { float_val: 2500 > >)", - [](double v) { return (v - 500) != 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](double v) { return (v - 500) != 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 data_type: Int8 @@ -2841,9 +2840,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 2 > >)", - [](int8_t v) { return (v * 2) != 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](int8_t v) { return (v * 2) != 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 102 data_type: Int16 @@ -2857,9 +2856,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 1000 > >)", - [](int16_t v) { return (v / 2) != 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](int16_t v) { return (v / 2) != 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 103 data_type: Int32 @@ -2873,9 +2872,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 0 > >)", - [](int32_t v) { return (v % 100) != 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](int32_t v) { return (v % 100) != 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 104 data_type: Int64 @@ -2889,9 +2888,397 @@ TEST(Expr, TestBinaryArithOpEvalRange) { int64_val: 2500 > >)", - [](int64_t v) { return (v + 500) != 2500; }, - DataType::INT64}, - }; + [](int64_t v) { return (v + 500) != 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) > 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) > 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) > 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) > 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) > 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) >= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) >= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) >= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) >= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) >= 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) < 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) < 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) < 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) < 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) < 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) <= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) <= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) <= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) <= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) <= 2500; }, + DataType::INT64}, + }; // std::string dsl_string_tmp = R"({ // "bool": { @@ -3052,24 +3439,769 @@ TEST(Expr, TestBinaryArithOpEvalRange) { } TEST(Expr, TestBinaryArithOpEvalRangeJSON) { - struct Testcase { - int64_t right_operand; - int64_t value; - OpType op; - std::vector nested_path; - }; - std::vector testcases{ - {10, 20, OpType::Equal, {"int"}}, - {20, 30, OpType::Equal, {"int"}}, - {30, 40, OpType::NotEqual, {"int"}}, - {40, 50, OpType::NotEqual, {"int"}}, - {10, 20, OpType::Equal, {"double"}}, - {20, 30, OpType::Equal, {"double"}}, - {30, 40, OpType::NotEqual, {"double"}}, - {40, 50, OpType::NotEqual, {"double"}}, - }; + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + std::vector< + std::tuple>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == 4; + }}, + // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length > 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length >= 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length < 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length <= 4; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); @@ -3094,48 +4226,26 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { auto seg_promote = dynamic_cast(seg.get()); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - for (auto testcase : testcases) { - auto check = [&](int64_t value) { - if (testcase.op == OpType::Equal) { - return value + testcase.right_operand == testcase.value; - } - return value + testcase.right_operand != testcase.value; - }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - proto::plan::GenericValue value; - value.set_int64_val(testcase.value); - proto::plan::GenericValue right_operand; - right_operand.set_int64_val(testcase.right_operand); - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - testcase.op, - ArithOpType::Add, - value, - right_operand); - BitsetType final; + + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - - if (testcase.nested_path[0] == "int") { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) << testcase.value << " " << val; - } else { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) << testcase.value << " " << val << " " - << testcase.op << " " << i; - } + auto ref = + ref_func(milvus::Json(simdjson::padded_string(json_col[i]))); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; } } } @@ -3400,6 +4510,214 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) { >)", [](int64_t v) { return (v + 500) != 2000; }, DataType::INT64}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) > 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) > 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) > 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) > 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) >= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) >= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) >= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) >= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) < 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) < 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) < 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) < 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) <= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) <= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) <= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) <= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, }; std::string serialized_expr_plan = R"(vector_anns: < diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index 8cc4f8924c..ebb76142b7 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -218,6 +218,15 @@ func TestExpr_BinaryArith(t *testing.T) { `Int64Field % 10 != 9`, `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, + `Int8Field + 1 < 2`, + `Int16Field - 3 <= 4`, + `Int32Field * 5 > 6`, + `Int64Field / 7 >= 8`, + `FloatField + 11 < 12`, + `DoubleField - 13 <= 14`, + `A * 15 > 16`, + `JSONField['A'] / 17 >= 18`, + `ArrayField[0] % 19 >= 20`, } for _, exprStr := range exprStrs { assertValidExpr(t, helper, exprStr) @@ -225,13 +234,6 @@ func TestExpr_BinaryArith(t *testing.T) { // TODO: enable these after execution backend is ready. unsupported := []string{ - `Int8Field + 1 < 2`, - `Int16Field - 3 <= 4`, - `Int32Field * 5 > 6`, - `Int64Field / 7 >= 8`, - `FloatField + 11 < 12`, - `DoubleField - 13 < 14`, - `A - 15 < 16`, `JSONField + 15 == 16`, `15 + JSONField == 16`, `ArrayField + 15 == 16`, @@ -1172,6 +1174,10 @@ func Test_ArrayLength(t *testing.T) { `array_length(B) != 1`, `not (array_length(C[0]) == 1)`, `not (array_length(C["D"]) != 1)`, + `array_length(StringArrayField) < 1`, + `array_length(StringArrayField) <= 1`, + `array_length(StringArrayField) > 5`, + `array_length(StringArrayField) >= 5`, } for _, expr = range exprs { _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ @@ -1193,7 +1199,6 @@ func Test_ArrayLength(t *testing.T) { `0 < array_length(a-b) < 2`, `0 < array_length(StringArrayField) < 1`, `100 > array_length(ArrayField) > 10`, - `array_length(StringArrayField) < 1`, `array_length(A) % 10 == 2`, `array_length(A) / 10 == 2`, `array_length(A) + 1 == 2`, diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index e0bfc1714f..24c042d615 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -259,14 +259,6 @@ func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, column } func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, valueExpr *planpb.ValueExpr) (*planpb.Expr, error) { - switch op { - case planpb.OpType_Equal, planpb.OpType_NotEqual: - break - default: - // TODO: enable this after execution is ready. - return nil, fmt.Errorf("%s is not supported in execution backend", op) - } - leftExpr, leftValue := arithExpr.Left.GetColumnExpr(), arithExpr.Left.GetValueExpr() rightExpr, rightValue := arithExpr.Right.GetColumnExpr(), arithExpr.Right.GetValueExpr() arithOp := arithExpr.GetOp() diff --git a/tests/integration/expression/expression_test.go b/tests/integration/expression/expression_test.go new file mode 100644 index 0000000000..859a7a4e87 --- /dev/null +++ b/tests/integration/expression/expression_test.go @@ -0,0 +1,240 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/tests/integration" +) + +type ExpressionSuite struct { + integration.MiniClusterSuite + dbName string + collectionName string + dim int + rowNum int +} + +func (s *ExpressionSuite) setParams() { + prefix := "TestExpression" + s.dbName = "" + s.collectionName = prefix + funcutil.GenRandomStr() + s.dim = 128 + s.rowNum = 100 +} + +func newJSONData(fieldName string, rowNum int) *schemapb.FieldData { + jsonData := make([][]byte, 0, rowNum) + for i := 0; i < rowNum; i++ { + data := map[string]interface{}{ + "A": i, + "B": rowNum - i, + "C": []int{i, rowNum - i}, + "D": fmt.Sprintf("name-%d", i), + "E": map[string]interface{}{ + "F": i, + "G": i + 10, + }, + "str1": `abc\"def-` + string(rune(i)), + "str2": fmt.Sprintf("abc\"def-%d", i), + "str3": fmt.Sprintf("abc\ndef-%d", i), + "str4": fmt.Sprintf("abc\367-%d", i), + } + if i%2 == 0 { + data = map[string]interface{}{ + "B": rowNum - i, + "C": []int{i, rowNum - i}, + "D": fmt.Sprintf("name-%d", i), + "E": map[string]interface{}{ + "F": i, + "G": i + 10, + }, + } + } + if i == 100 { + data = nil + } + jsonBytes, err := json.MarshalIndent(data, "", " ") + if err != nil { + return nil + } + jsonData = append(jsonData, jsonBytes) + } + return &schemapb.FieldData{ + Type: schemapb.DataType_JSON, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: jsonData, + }, + }, + }, + }, + } +} + +func (s *ExpressionSuite) insertFlushIndexLoad(ctx context.Context, fieldData []*schemapb.FieldData) { + hashKeys := integration.GenerateHashKeys(s.rowNum) + insertResult, err := s.Cluster.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: s.dbName, + CollectionName: s.collectionName, + FieldsData: fieldData, + HashKeys: hashKeys, + NumRows: uint32(s.rowNum), + }) + s.NoError(err) + s.NoError(merr.Error(insertResult.GetStatus())) + + // flush + flushResp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: s.dbName, + CollectionNames: []string{s.collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[s.collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[s.collectionName] + s.True(has) + + segments, err := s.Cluster.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + for _, segment := range segments { + log.Info("ShowSegments result", zap.String("segment", segment.String())) + } + s.WaitForFlush(ctx, ids, flushTs, s.dbName, s.collectionName) + + // create index + createIndexStatus, err := s.Cluster.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{ + CollectionName: s.collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + s.NoError(err) + s.WaitForIndexBuilt(context.TODO(), s.collectionName, integration.FloatVecField) + log.Info("=========================Index created=========================") + + // load + loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: s.dbName, + CollectionName: s.collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + s.NoError(err) + s.WaitForLoad(context.TODO(), s.collectionName) + log.Info("=========================Collection loaded=========================") +} + +func (s *ExpressionSuite) setupData() { + c := s.Cluster + ctx, cancel := context.WithCancel(c.GetContext()) + defer cancel() + + schema := integration.ConstructSchema(s.collectionName, s.dim, true) + schema.EnableDynamicField = true + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: s.dbName, + CollectionName: s.collectionName, + Schema: marshaledSchema, + ShardsNum: 2, + }) + s.NoError(err) + err = merr.Error(createCollectionStatus) + s.NoError(err) + + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + err = merr.Error(showCollectionsResp.GetStatus()) + s.NoError(err) + + describeCollectionResp, err := c.Proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{CollectionName: s.collectionName}) + s.NoError(err) + err = merr.Error(describeCollectionResp.GetStatus()) + s.NoError(err) + s.True(describeCollectionResp.Schema.EnableDynamicField) + s.Equal(2, len(describeCollectionResp.GetSchema().GetFields())) + + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, s.rowNum, s.dim) + jsonData := newJSONData(common.MetaFieldName, s.rowNum) + jsonData.IsDynamic = true + s.insertFlushIndexLoad(ctx, []*schemapb.FieldData{fVecColumn, jsonData}) +} + +type testCase struct { + expr string + topK int + resNum int +} + +func (s *ExpressionSuite) searchWithExpression() { + testcases := []testCase{ + {"A + 5 > 0", 10, 10}, + {"B - 5 >= 0", 10, 10}, + {"C[0] * 5 < 500", 10, 10}, + {"E['F'] / 5 <= 100", 10, 10}, + {"E['G'] % 5 == 4", 10, 10}, + {"A / 5 != 4", 10, 10}, + } + for _, c := range testcases { + params := integration.GetSearchParams(integration.IndexFaissIDMap, metric.IP) + searchReq := integration.ConstructSearchRequest(s.dbName, s.collectionName, c.expr, + integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, 1, s.dim, c.topK, -1) + + searchResult, err := s.Cluster.Proxy.Search(context.Background(), searchReq) + s.NoError(err) + err = merr.Error(searchResult.GetStatus()) + s.NoError(err) + s.Equal(c.resNum, len(searchResult.GetResults().GetScores())) + log.Info(fmt.Sprintf("=========================Search done with expr:%s =========================", c.expr)) + } +} + +func (s *ExpressionSuite) TestExpression() { + s.setParams() + s.setupData() + s.searchWithExpression() +} + +func TestExpression(t *testing.T) { + suite.Run(t, new(ExpressionSuite)) +} diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 4c0675c14f..bc6a5623c5 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -627,12 +627,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 2. search expression = "int32_array[0] - 1 < 1" - error = {ct.err_code: 65535, - ct.err_msg: f"failed to create query plan: cannot parse expression: {expression}, " - f"error: LessThan is not supported in execution backend"} collection_w.search(vectors[:default_nq], default_search_field, - default_search_params, nb, expression, - check_task=CheckTasks.err_res, check_items=error) + default_search_params, nb, expression) @pytest.mark.tags(CaseLabel.L2) def test_search_partition_invalid_type(self, get_invalid_partition):