From 3628593d20d8e0582d8423c35530cb0a95041123 Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Fri, 25 Oct 2024 15:25:30 +0800 Subject: [PATCH] feat: Implement custom function module in milvus expr (#36560) OSPP 2024 project: https://summer-ospp.ac.cn/org/prodetail/247410235?list=org&navpage=org Solutions: - parser (planparserv2) - add CallExpr in planparserv2/Plan.g4 - update parser_visitor and show_visitor - grpc protobuf - add CallExpr in plan.proto - execution (`core/src/exec`) - add `CallExpr` `ValueExpr` and `ColumnExpr` (both logical and physical) for function call and function parameters - function factory (`core/src/exec/expression/function`) - create a global hashmap when starting milvus (see server.go) - the global hashmap stores function signatures and their function pointers, the CallExpr in execution engine can get the function pointer by function signature. - custom functions - empty(string) - starts_with(string, string) - add cpp/go unittests and E2E tests closes: #36559 Signed-off-by: Yinzuo Jiang --- docs/design_docs/segcore/visitor.md | 13 +- internal/core/CMakeLists.txt | 6 + internal/core/src/common/Types.h | 3 +- internal/core/src/common/Vector.h | 117 ++++- internal/core/src/common/init_c.cpp | 2 +- internal/core/src/exec/Task.cpp | 2 +- .../core/src/exec/expression/CallExpr.cpp | 46 ++ internal/core/src/exec/expression/CallExpr.h | 83 +++ .../core/src/exec/expression/ColumnExpr.cpp | 147 ++++++ .../core/src/exec/expression/ColumnExpr.h | 125 +++++ .../core/src/exec/expression/CompareExpr.cpp | 361 ++----------- .../core/src/exec/expression/CompareExpr.h | 261 ++++------ internal/core/src/exec/expression/Expr.cpp | 42 +- internal/core/src/exec/expression/Expr.h | 7 +- .../core/src/exec/expression/ValueExpr.cpp | 101 ++++ internal/core/src/exec/expression/ValueExpr.h | 67 +++ .../expression/function/FunctionFactory.cpp | 83 +++ .../expression/function/FunctionFactory.h | 111 ++++ .../expression/function/FunctionImplUtils.cpp | 30 ++ .../expression/function/FunctionImplUtils.h | 25 + .../exec/expression/function/impl/Empty.cpp | 59 +++ .../expression/function/impl/StartsWith.cpp | 64 +++ .../function/impl/StringFunctions.h | 35 ++ .../src/exec/expression/function/init_c.cpp | 23 + .../src/exec/expression/function/init_c.h | 28 + .../core/src/exec/operator/FilterBitsNode.cpp | 27 +- internal/core/src/expr/ITypeExpr.h | 144 ++++-- internal/core/src/query/PlanProto.cpp | 87 +++- internal/core/src/query/PlanProto.h | 48 +- .../core/src/segcore/SegmentChunkReader.cpp | 327 ++++++++++++ .../core/src/segcore/SegmentChunkReader.h | 141 +++++ internal/core/src/segcore/SegmentInterface.h | 4 +- internal/core/unittest/CMakeLists.txt | 1 + internal/core/unittest/test_exec.cpp | 63 ++- internal/core/unittest/test_expr.cpp | 171 ++++++ internal/core/unittest/test_function.cpp | 239 +++++++++ internal/parser/planparserv2/Plan.g4 | 1 + .../planparserv2/check_identical_test.go | 28 +- .../parser/planparserv2/generated/Plan.interp | 2 +- .../generated/plan_base_visitor.go | 4 + .../planparserv2/generated/plan_parser.go | 488 ++++++++++++------ .../planparserv2/generated/plan_visitor.go | 3 + .../parser/planparserv2/parser_visitor.go | 22 + .../planparserv2/plan_parser_v2_test.go | 59 ++- internal/parser/planparserv2/show_visitor.go | 16 +- internal/parser/planparserv2/utils.go | 10 +- internal/proto/plan.proto | 6 + internal/querynodev2/server.go | 2 + tests/go_client/common/utils.go | 4 +- tests/python_client/testcases/test_query.py | 56 ++ 50 files changed, 3003 insertions(+), 791 deletions(-) create mode 100644 internal/core/src/exec/expression/CallExpr.cpp create mode 100644 internal/core/src/exec/expression/CallExpr.h create mode 100644 internal/core/src/exec/expression/ColumnExpr.cpp create mode 100644 internal/core/src/exec/expression/ColumnExpr.h create mode 100644 internal/core/src/exec/expression/ValueExpr.cpp create mode 100644 internal/core/src/exec/expression/ValueExpr.h create mode 100644 internal/core/src/exec/expression/function/FunctionFactory.cpp create mode 100644 internal/core/src/exec/expression/function/FunctionFactory.h create mode 100644 internal/core/src/exec/expression/function/FunctionImplUtils.cpp create mode 100644 internal/core/src/exec/expression/function/FunctionImplUtils.h create mode 100644 internal/core/src/exec/expression/function/impl/Empty.cpp create mode 100644 internal/core/src/exec/expression/function/impl/StartsWith.cpp create mode 100644 internal/core/src/exec/expression/function/impl/StringFunctions.h create mode 100644 internal/core/src/exec/expression/function/init_c.cpp create mode 100644 internal/core/src/exec/expression/function/init_c.h create mode 100644 internal/core/src/segcore/SegmentChunkReader.cpp create mode 100644 internal/core/src/segcore/SegmentChunkReader.h create mode 100644 internal/core/unittest/test_function.cpp diff --git a/docs/design_docs/segcore/visitor.md b/docs/design_docs/segcore/visitor.md index 6cf70d7a8f..2c8fd5568f 100644 --- a/docs/design_docs/segcore/visitor.md +++ b/docs/design_docs/segcore/visitor.md @@ -1,20 +1,15 @@ # Visitor Pattern Visitor Pattern is used in segcore for parse and execute Execution Plan. -1. Inside `${core}/src/query/PlanNode.h`, contains physical plan for vector search: +1. Inside `${internal/core}/src/query/PlanNode.h`, contains physical plan for vector search: 1. `FloatVectorANNS` FloatVector search execution node 2. `BinaryVectorANNS` BinaryVector search execution node -2. `${core}/src/query/Expr.h` contains physical plan for scalar expression: +2. `${internal/core}/src/query/Expr.h` contains physical plan for scalar expression: 1. `TermExpr` support operation like `col in [1, 2, 3]` 2. `RangeExpr` support constant compare with data column like `a >= 5` `1 < b < 2` 3. `CompareExpr` support compare with different columns, like `a < b` 4. `LogicalBinaryExpr` support and/or 5. `LogicalUnaryExpr` support not -Currently, under `${core/query/visitors}` directory, there are the following visitors: -1. `ShowPlanNodeVisitor` prints PlanNode in json -2. `ShowExprVisitor` Expr -> json -3. `Verify...Visitor` validates ... -4. `ExtractInfo...Visitor` extracts info from..., including involved_fields and else -5. `ExecExprVisitor` generates bitmask according to expression -6. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now +Currently, under `${internal/core/src/query}` directory, there are the following visitors: +1. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index f5530f8e4c..2b33959183 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -292,6 +292,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/segcore/ FILES_MATCHING PATTERN "*_c.h" ) +# Install exec/expression/function +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/exec/expression/function/ + DESTINATION include/exec/expression/function + FILES_MATCHING PATTERN "*_c.h" +) + # Install indexbuilder install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/ DESTINATION include/indexbuilder diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 2473b21a88..d26d2ee2ed 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -59,6 +59,7 @@ using float16 = knowhere::fp16; using bfloat16 = knowhere::bf16; using bin1 = knowhere::bin1; +// See also: https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto enum class DataType { NONE = 0, BOOL = 1, @@ -682,4 +683,4 @@ struct fmt::formatter : formatter { } return formatter::format(name, ctx); } -}; \ No newline at end of file +}; diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index afc1d4766e..6fa073e1d7 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -17,11 +17,13 @@ #pragma once #include -#include #include "EasyAssert.h" #include "Types.h" +#include "bitset/bitset.h" #include "common/FieldData.h" +#include "common/FieldDataInterface.h" +#include "common/Types.h" namespace milvus { @@ -29,7 +31,6 @@ namespace milvus { * @brief base class for different type vector * @todo implement full null value support */ - class BaseVector { public: BaseVector(DataType data_type, @@ -58,18 +59,39 @@ class BaseVector { using VectorPtr = std::shared_ptr; +/** + * SimpleVector abstracts over various Columnar Storage Formats, + * it is used in custom functions. + */ +class SimpleVector : public BaseVector { + public: + SimpleVector(DataType data_type, + size_t length, + std::optional null_count = std::nullopt) + : BaseVector(data_type, length, null_count) { + } + + virtual void* + RawValueAt(size_t index, size_t size_of_element) = 0; + + virtual bool + ValidAt(size_t index) = 0; +}; + /** * @brief Single vector for scalar types * @todo using memory pool && buffer replace FieldData */ -class ColumnVector final : public BaseVector { +class ColumnVector final : public SimpleVector { public: ColumnVector(DataType data_type, size_t length, std::optional null_count = std::nullopt) - : BaseVector(data_type, length, null_count) { + : SimpleVector(data_type, length, null_count), + is_bitmap_(false), + valid_values_(length, + !null_count.has_value() || null_count.value() == 0) { values_ = InitScalarFieldData(data_type, false, length); - valid_values_ = InitScalarFieldData(data_type, false, length); } // ColumnVector(FixedVector&& data) @@ -78,20 +100,14 @@ class ColumnVector final : public BaseVector { // std::make_shared>(DataType::BOOL, std::move(data)); // } - // // the size is the number of bits - // ColumnVector(TargetBitmap&& bitmap) - // : BaseVector(DataType::INT8, bitmap.size()) { - // values_ = std::make_shared>( - // bitmap.size(), DataType::INT8, false, std::move(bitmap).into()); - // } - // the size is the number of bits + // TODO: separate the usage of bitmap from scalar field data ColumnVector(TargetBitmap&& bitmap, TargetBitmap&& valid_bitmap) - : BaseVector(DataType::INT8, bitmap.size()) { + : SimpleVector(DataType::INT8, bitmap.size()), + is_bitmap_(true), + valid_values_(std::move(valid_bitmap)) { values_ = std::make_shared>(DataType::INT8, std::move(bitmap)); - valid_values_ = std::make_shared>( - DataType::INT8, std::move(valid_bitmap)); } virtual ~ColumnVector() override { @@ -100,28 +116,81 @@ class ColumnVector final : public BaseVector { } void* - GetRawData() { + RawValueAt(size_t index, size_t size_of_element) override { + return reinterpret_cast(GetRawData()) + index * size_of_element; + } + + bool + ValidAt(size_t index) override { + return valid_values_[index]; + } + + void* + GetRawData() const { return values_->Data(); } void* GetValidRawData() { - return valid_values_->Data(); + return valid_values_.data(); } template - const As* + As* RawAsValues() const { - return reinterpret_cast(values_->Data()); + return reinterpret_cast(values_->Data()); + } + + bool + IsBitmap() const { + return is_bitmap_; } private: + bool is_bitmap_; // TODO: remove the field after implementing BitmapVector FieldDataPtr values_; - FieldDataPtr valid_values_; + TargetBitmap valid_values_; // false means the value is null }; using ColumnVectorPtr = std::shared_ptr; +template +class ConstantVector : public SimpleVector { + public: + ConstantVector(DataType data_type, + size_t length, + const T& val, + std::optional null_count = std::nullopt) + : SimpleVector(data_type, length), + val_(val), + is_null_(null_count.has_value() && null_count.value() > 0) { + } + + void* + RawValueAt(size_t _index, size_t _size_of_element) override { + return &val_; + } + + bool + ValidAt(size_t _index) override { + return !is_null_; + } + + const T& + GetValue() const { + return val_; + } + + bool + IsNull() const { + return is_null_; + } + + private: + T val_; + bool is_null_; +}; + /** * @brief Multi vectors for scalar types * mainly using it to pass internal result in segcore scalar engine system @@ -149,8 +218,7 @@ class RowVector : public BaseVector { } RowVector(std::vector&& children) - : BaseVector(DataType::ROW, 0) { - children_values_ = std::move(children); + : BaseVector(DataType::ROW, 0), children_values_(std::move(children)) { for (auto& child : children_values_) { if (child->size() > length_) { length_ = child->size(); @@ -159,12 +227,12 @@ class RowVector : public BaseVector { } const std::vector& - childrens() { + childrens() const { return children_values_; } VectorPtr - child(int index) { + child(int index) const { assert(index < children_values_.size()); return children_values_[index]; } @@ -174,5 +242,4 @@ class RowVector : public BaseVector { }; using RowVectorPtr = std::shared_ptr; - } // namespace milvus diff --git a/internal/core/src/common/init_c.cpp b/internal/core/src/common/init_c.cpp index ce961b7d8b..77764ffa55 100644 --- a/internal/core/src/common/init_c.cpp +++ b/internal/core/src/common/init_c.cpp @@ -105,4 +105,4 @@ SetTrace(CTraceConfig* config) { config->oltpSecure, config->nodeID}; milvus::tracer::initTelemetry(traceConfig); -} \ No newline at end of file +} diff --git a/internal/core/src/exec/Task.cpp b/internal/core/src/exec/Task.cpp index d03ca3f97f..14731417f0 100644 --- a/internal/core/src/exec/Task.cpp +++ b/internal/core/src/exec/Task.cpp @@ -235,4 +235,4 @@ Task::Next(ContinueFuture* future) { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/exec/expression/CallExpr.cpp b/internal/core/src/exec/expression/CallExpr.cpp new file mode 100644 index 0000000000..0e6fb0fc5c --- /dev/null +++ b/internal/core/src/exec/expression/CallExpr.cpp @@ -0,0 +1,46 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/FieldDataInterface.h" +#include "common/Vector.h" +#include "exec/expression/CallExpr.h" +#include "exec/expression/EvalCtx.h" +#include "exec/expression/function/FunctionFactory.h" + +#include +#include + +namespace milvus { +namespace exec { + +void +PhyCallExpr::Eval(EvalCtx& context, VectorPtr& result) { + AssertInfo(inputs_.size() == expr_->inputs().size(), + "logical call expr needs {} inputs, but {} inputs are provided", + expr_->inputs().size(), + inputs_.size()); + std::vector args; + for (auto& input : this->inputs_) { + VectorPtr arg_result; + input->Eval(context, arg_result); + args.push_back(std::move(arg_result)); + } + RowVector row_vector(std::move(args)); + this->expr_->function_ptr()(row_vector, result); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CallExpr.h b/internal/core/src/exec/expression/CallExpr.h new file mode 100644 index 0000000000..f074c7b423 --- /dev/null +++ b/internal/core/src/exec/expression/CallExpr.h @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "common/EasyAssert.h" +#include "common/FieldDataInterface.h" +#include "common/Utils.h" +#include "common/Vector.h" +#include "exec/expression/EvalCtx.h" +#include "exec/expression/Expr.h" +#include "exec/expression/function/FunctionFactory.h" +#include "expr/ITypeExpr.h" +#include "fmt/core.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyCallExpr : public Expr { + public: + PhyCallExpr(const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(DataType::BOOL, std::move(input), name), + expr_(expr), + active_count_(active_count), + segment_(segment), + batch_size_(batch_size) { + size_per_chunk_ = segment_->size_per_chunk(); + num_chunk_ = upper_div(active_count_, size_per_chunk_); + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + void + MoveCursor() override { + for (auto input : inputs_) { + input->MoveCursor(); + } + } + + private: + std::shared_ptr expr_; + + int64_t active_count_{0}; + int64_t num_chunk_{0}; + int64_t current_chunk_id_{0}; + int64_t current_chunk_pos_{0}; + int64_t size_per_chunk_{0}; + + const segcore::SegmentInternalInterface* segment_; + int64_t batch_size_; +}; + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ColumnExpr.cpp b/internal/core/src/exec/expression/ColumnExpr.cpp new file mode 100644 index 0000000000..933155bb3a --- /dev/null +++ b/internal/core/src/exec/expression/ColumnExpr.cpp @@ -0,0 +1,147 @@ +// Licensed to the LF AI & Data foundation under on0 +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ColumnExpr.h" + +namespace milvus { +namespace exec { + +int64_t +PhyColumnExpr::GetNextBatchSize() { + auto current_rows = GetCurrentRows(); + + return current_rows + batch_size_ >= segment_chunk_reader_.active_count_ + ? segment_chunk_reader_.active_count_ - current_rows + : batch_size_; +} + +void +PhyColumnExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (this->expr_->type()) { + case DataType::BOOL: + result = DoEval(); + break; + case DataType::INT8: + result = DoEval(); + break; + case DataType::INT16: + result = DoEval(); + break; + case DataType::INT32: + result = DoEval(); + break; + case DataType::INT64: + result = DoEval(); + break; + case DataType::FLOAT: + result = DoEval(); + break; + case DataType::DOUBLE: + result = DoEval(); + break; + case DataType::VARCHAR: { + result = DoEval(); + break; + } + default: + PanicInfo(DataTypeInvalid, + "unsupported data type: {}", + this->expr_->type()); + } +} + +template +VectorPtr +PhyColumnExpr::DoEval() { + // similar to PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) + if (segment_chunk_reader_.segment_->is_chunked()) { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = std::make_shared( + expr_->GetColumn().data_type_, real_batch_size); + T* res_value = res_vec->RawAsValues(); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); + auto chunk_data = segment_chunk_reader_.GetChunkDataAccessor( + expr_->GetColumn().data_type_, + expr_->GetColumn().field_id_, + is_indexed_, + current_chunk_id_, + current_chunk_pos_); + for (int i = 0; i < real_batch_size; ++i) { + if (!chunk_data().has_value()) { + valid_res[i] = false; + continue; + } + res_value[i] = boost::get(chunk_data().value()); + } + return res_vec; + } else { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = std::make_shared( + expr_->GetColumn().data_type_, real_batch_size); + T* res_value = res_vec->RawAsValues(); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); + + auto data_barrier = segment_chunk_reader_.segment_->num_chunk_data( + expr_->GetColumn().field_id_); + + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = + chunk_id == num_chunk_ - 1 + ? segment_chunk_reader_.active_count_ - + chunk_id * segment_chunk_reader_.SizePerChunk() + : segment_chunk_reader_.SizePerChunk(); + auto chunk_data = segment_chunk_reader_.GetChunkDataAccessor( + expr_->GetColumn().data_type_, + expr_->GetColumn().field_id_, + chunk_id, + data_barrier); + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + if (!chunk_data(i).has_value()) { + valid_res[processed_rows] = false; + } else { + res_value[processed_rows] = + boost::get(chunk_data(i).value()); + } + processed_rows++; + + if (processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + return res_vec; + } + } + } + return res_vec; + } +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ColumnExpr.h b/internal/core/src/exec/expression/ColumnExpr.h new file mode 100644 index 0000000000..4b8bdfd936 --- /dev/null +++ b/internal/core/src/exec/expression/ColumnExpr.h @@ -0,0 +1,125 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" +#include "segcore/SegmentChunkReader.h" + +namespace milvus { +namespace exec { + +class PhyColumnExpr : public Expr { + public: + PhyColumnExpr(const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(expr->type(), std::move(input), name), + segment_chunk_reader_(segment, active_count), + batch_size_(batch_size), + expr_(expr) { + is_indexed_ = segment->HasIndex(expr_->GetColumn().field_id_); + if (segment->is_chunked()) { + num_chunk_ = + is_indexed_ + ? segment->num_chunk_index(expr_->GetColumn().field_id_) + : segment->type() == SegmentType::Growing + ? upper_div(segment_chunk_reader_.active_count_, + segment_chunk_reader_.SizePerChunk()) + : segment->num_chunk_data(expr_->GetColumn().field_id_); + } else { + num_chunk_ = + is_indexed_ + ? segment->num_chunk_index(expr_->GetColumn().field_id_) + : upper_div(segment_chunk_reader_.active_count_, + segment_chunk_reader_.SizePerChunk()); + } + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + void + MoveCursor() override { + if (segment_chunk_reader_.segment_->is_chunked()) { + segment_chunk_reader_.MoveCursorForMultipleChunk( + current_chunk_id_, + current_chunk_pos_, + expr_->GetColumn().field_id_, + num_chunk_, + batch_size_); + } else { + segment_chunk_reader_.MoveCursorForSingleChunk( + current_chunk_id_, current_chunk_pos_, num_chunk_, batch_size_); + } + } + + private: + int64_t + GetCurrentRows() const { + if (segment_chunk_reader_.segment_->is_chunked()) { + auto current_rows = + is_indexed_ && segment_chunk_reader_.segment_->type() == + SegmentType::Sealed + ? current_chunk_pos_ + : segment_chunk_reader_.segment_->num_rows_until_chunk( + expr_->GetColumn().field_id_, current_chunk_id_) + + current_chunk_pos_; + return current_rows; + } else { + return segment_chunk_reader_.segment_->type() == + SegmentType::Growing + ? current_chunk_id_ * + segment_chunk_reader_.SizePerChunk() + + current_chunk_pos_ + : current_chunk_pos_; + } + } + + int64_t + GetNextBatchSize(); + + template + VectorPtr + DoEval(); + + private: + bool is_indexed_; + + int64_t num_chunk_{0}; + int64_t current_chunk_id_{0}; + int64_t current_chunk_pos_{0}; + + const segcore::SegmentChunkReader segment_chunk_reader_; + int64_t batch_size_; + std::shared_ptr expr_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index 5bc2e8dab1..8916d366c2 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -15,7 +15,6 @@ // limitations under the License. #include "CompareExpr.h" -#include "common/type_c.h" #include #include "query/Relational.h" @@ -32,212 +31,15 @@ int64_t PhyCompareFilterExpr::GetNextBatchSize() { auto current_rows = GetCurrentRows(); - return current_rows + batch_size_ >= active_count_ - ? active_count_ - current_rows + return current_rows + batch_size_ >= segment_chunk_reader_.active_count_ + ? segment_chunk_reader_.active_count_ - current_rows : batch_size_; } -template -MultipleChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(FieldId field_id, - bool index, - int64_t& current_chunk_id, - int64_t& current_chunk_pos) { - if (index) { - auto& indexing = const_cast&>( - segment_->chunk_scalar_index(field_id, current_chunk_id)); - auto current_chunk_size = segment_->type() == SegmentType::Growing - ? size_per_chunk_ - : active_count_; - - if (indexing.HasRawData()) { - return [&, current_chunk_size]() -> const number { - if (current_chunk_pos >= current_chunk_size) { - current_chunk_id++; - current_chunk_pos = 0; - indexing = const_cast&>( - segment_->chunk_scalar_index(field_id, - current_chunk_id)); - } - auto raw = indexing.Reverse_Lookup(current_chunk_pos); - current_chunk_pos++; - if (!raw.has_value()) { - return std::nullopt; - } - return raw.value(); - }; - } - } - auto chunk_data = - segment_->chunk_data(field_id, current_chunk_id).data(); - auto chunk_valid_data = - segment_->chunk_data(field_id, current_chunk_id).valid_data(); - auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); - return - [=, ¤t_chunk_id, ¤t_chunk_pos]() mutable -> const number { - if (current_chunk_pos >= current_chunk_size) { - current_chunk_id++; - current_chunk_pos = 0; - chunk_data = - segment_->chunk_data(field_id, current_chunk_id).data(); - chunk_valid_data = - segment_->chunk_data(field_id, current_chunk_id) - .valid_data(); - current_chunk_size = - segment_->chunk_size(field_id, current_chunk_id); - } - if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { - current_chunk_pos++; - return std::nullopt; - } - return chunk_data[current_chunk_pos++]; - }; -} - -template <> -MultipleChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(FieldId field_id, - bool index, - int64_t& current_chunk_id, - int64_t& current_chunk_pos) { - if (index) { - auto& indexing = const_cast&>( - segment_->chunk_scalar_index(field_id, - current_chunk_id)); - auto current_chunk_size = segment_->type() == SegmentType::Growing - ? size_per_chunk_ - : active_count_; - - if (indexing.HasRawData()) { - return [&, current_chunk_size]() mutable -> const number { - if (current_chunk_pos >= current_chunk_size) { - current_chunk_id++; - current_chunk_pos = 0; - indexing = const_cast&>( - segment_->chunk_scalar_index( - field_id, current_chunk_id)); - } - auto raw = indexing.Reverse_Lookup(current_chunk_pos); - current_chunk_pos++; - if (!raw.has_value()) { - return std::nullopt; - } - return raw.value(); - }; - } - } - if (segment_->type() == SegmentType::Growing && - !storage::MmapManager::GetInstance() - .GetMmapConfig() - .growing_enable_mmap) { - auto chunk_data = - segment_->chunk_data(field_id, current_chunk_id) - .data(); - auto chunk_valid_data = - segment_->chunk_data(field_id, current_chunk_id) - .valid_data(); - auto current_chunk_size = - segment_->chunk_size(field_id, current_chunk_id); - return [=, - ¤t_chunk_id, - ¤t_chunk_pos]() mutable -> const number { - if (current_chunk_pos >= current_chunk_size) { - current_chunk_id++; - current_chunk_pos = 0; - chunk_data = - segment_ - ->chunk_data(field_id, current_chunk_id) - .data(); - chunk_valid_data = - segment_ - ->chunk_data(field_id, current_chunk_id) - .valid_data(); - current_chunk_size = - segment_->chunk_size(field_id, current_chunk_id); - } - if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { - current_chunk_pos++; - return std::nullopt; - } - return chunk_data[current_chunk_pos++]; - }; - } else { - auto chunk_data = - segment_->chunk_view(field_id, current_chunk_id) - .first.data(); - auto chunk_valid_data = - segment_->chunk_data(field_id, current_chunk_id) - .valid_data(); - auto current_chunk_size = - segment_->chunk_size(field_id, current_chunk_id); - return [=, - ¤t_chunk_id, - ¤t_chunk_pos]() mutable -> const number { - if (current_chunk_pos >= current_chunk_size) { - current_chunk_id++; - current_chunk_pos = 0; - chunk_data = segment_ - ->chunk_view( - field_id, current_chunk_id) - .first.data(); - chunk_valid_data = segment_ - ->chunk_data( - field_id, current_chunk_id) - .valid_data(); - current_chunk_size = - segment_->chunk_size(field_id, current_chunk_id); - } - if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { - current_chunk_pos++; - return std::nullopt; - } - - return std::string(chunk_data[current_chunk_pos++]); - }; - } -} - -MultipleChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(DataType data_type, - FieldId field_id, - bool index, - int64_t& current_chunk_id, - int64_t& current_chunk_pos) { - switch (data_type) { - case DataType::BOOL: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::INT8: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::INT16: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::INT32: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::INT64: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::FLOAT: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::DOUBLE: - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - case DataType::VARCHAR: { - return GetChunkData( - field_id, index, current_chunk_id, current_chunk_pos); - } - default: - PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); - } -} - template VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { - if (segment_->is_chunked()) { + if (segment_chunk_reader_.segment_->is_chunked()) { auto real_batch_size = GetNextBatchSize(); if (real_batch_size == 0) { return nullptr; @@ -249,16 +51,18 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); valid_res.set(); - auto left = GetChunkData(expr_->left_data_type_, - expr_->left_field_id_, - is_left_indexed_, - left_current_chunk_id_, - left_current_chunk_pos_); - auto right = GetChunkData(expr_->right_data_type_, - expr_->right_field_id_, - is_right_indexed_, - right_current_chunk_id_, - right_current_chunk_pos_); + auto left = + segment_chunk_reader_.GetChunkDataAccessor(expr_->left_data_type_, + expr_->left_field_id_, + is_left_indexed_, + left_current_chunk_id_, + left_current_chunk_pos_); + auto right = segment_chunk_reader_.GetChunkDataAccessor( + expr_->right_data_type_, + expr_->right_field_id_, + is_right_indexed_, + right_current_chunk_id_, + right_current_chunk_pos_); for (int i = 0; i < real_batch_size; ++i) { if (!left().has_value() || !right().has_value()) { res[i] = false; @@ -283,25 +87,30 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); valid_res.set(); - auto left_data_barrier = - segment_->num_chunk_data(expr_->left_field_id_); + auto left_data_barrier = segment_chunk_reader_.segment_->num_chunk_data( + expr_->left_field_id_); auto right_data_barrier = - segment_->num_chunk_data(expr_->right_field_id_); + segment_chunk_reader_.segment_->num_chunk_data( + expr_->right_field_id_); int64_t processed_rows = 0; for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; ++chunk_id) { - auto chunk_size = chunk_id == num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; - auto left = GetChunkData(expr_->left_data_type_, - expr_->left_field_id_, - chunk_id, - left_data_barrier); - auto right = GetChunkData(expr_->right_data_type_, - expr_->right_field_id_, - chunk_id, - right_data_barrier); + auto chunk_size = + chunk_id == num_chunk_ - 1 + ? segment_chunk_reader_.active_count_ - + chunk_id * segment_chunk_reader_.SizePerChunk() + : segment_chunk_reader_.SizePerChunk(); + auto left = segment_chunk_reader_.GetChunkDataAccessor( + expr_->left_data_type_, + expr_->left_field_id_, + chunk_id, + left_data_barrier); + auto right = segment_chunk_reader_.GetChunkDataAccessor( + expr_->right_data_type_, + expr_->right_field_id_, + chunk_id, + right_data_barrier); for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; i < chunk_size; @@ -328,108 +137,6 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { } } -template -ChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(FieldId field_id, - int chunk_id, - int data_barrier) { - if (chunk_id >= data_barrier) { - auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - auto raw = indexing.Reverse_Lookup(i); - if (!raw.has_value()) { - return std::nullopt; - } - return raw.value(); - }; - } - } - auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); - auto chunk_valid_data = - segment_->chunk_data(field_id, chunk_id).valid_data(); - return [chunk_data, chunk_valid_data](int i) -> const number { - if (chunk_valid_data && !chunk_valid_data[i]) { - return std::nullopt; - } - return chunk_data[i]; - }; -} - -template <> -ChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(FieldId field_id, - int chunk_id, - int data_barrier) { - if (chunk_id >= data_barrier) { - auto& indexing = - segment_->chunk_scalar_index(field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - auto raw = indexing.Reverse_Lookup(i); - if (!raw.has_value()) { - return std::nullopt; - } - return raw.value(); - }; - } - } - if (segment_->type() == SegmentType::Growing && - !storage::MmapManager::GetInstance() - .GetMmapConfig() - .growing_enable_mmap) { - auto chunk_data = - segment_->chunk_data(field_id, chunk_id).data(); - auto chunk_valid_data = - segment_->chunk_data(field_id, chunk_id).valid_data(); - return [chunk_data, chunk_valid_data](int i) -> const number { - if (chunk_valid_data && !chunk_valid_data[i]) { - return std::nullopt; - } - return chunk_data[i]; - }; - } else { - auto chunk_info = - segment_->chunk_view(field_id, chunk_id); - auto chunk_data = chunk_info.first.data(); - auto chunk_valid_data = chunk_info.second.data(); - return [chunk_data, chunk_valid_data](int i) -> const number { - if (chunk_valid_data && !chunk_valid_data[i]) { - return std::nullopt; - } - return std::string(chunk_data[i]); - }; - } -} - -ChunkDataAccessor -PhyCompareFilterExpr::GetChunkData(DataType data_type, - FieldId field_id, - int chunk_id, - int data_barrier) { - switch (data_type) { - case DataType::BOOL: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::INT8: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::INT16: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::INT32: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::INT64: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::FLOAT: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::DOUBLE: - return GetChunkData(field_id, chunk_id, data_barrier); - case DataType::VARCHAR: { - return GetChunkData(field_id, chunk_id, data_barrier); - } - default: - PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); - } -} - void PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { // For segment both fields has no index, can use SIMD to speed up. diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index 8f4aaaed53..b881535744 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -18,7 +18,6 @@ #include #include -#include #include "common/EasyAssert.h" #include "common/Types.h" @@ -26,24 +25,11 @@ #include "common/type_c.h" #include "exec/expression/Expr.h" #include "segcore/SegmentInterface.h" +#include "segcore/SegmentChunkReader.h" namespace milvus { namespace exec { -using number_type = boost::variant; - -using number = std::optional; - -using ChunkDataAccessor = std::function; -using MultipleChunkDataAccessor = std::function; - template struct CompareElementFunc { void @@ -113,31 +99,32 @@ class PhyCompareFilterExpr : public Expr { : Expr(DataType::BOOL, std::move(input), name), left_field_(expr->left_field_id_), right_field_(expr->right_field_id_), - segment_(segment), - active_count_(active_count), + segment_chunk_reader_(segment, active_count), batch_size_(batch_size), expr_(expr) { - is_left_indexed_ = segment_->HasIndex(left_field_); - is_right_indexed_ = segment_->HasIndex(right_field_); - size_per_chunk_ = segment_->size_per_chunk(); - if (segment_->is_chunked()) { + is_left_indexed_ = segment->HasIndex(left_field_); + is_right_indexed_ = segment->HasIndex(right_field_); + if (segment->is_chunked()) { left_num_chunk_ = is_left_indexed_ - ? segment_->num_chunk_index(expr_->left_field_id_) - : segment_->type() == SegmentType::Growing - ? upper_div(active_count_, size_per_chunk_) - : segment_->num_chunk_data(left_field_); + ? segment->num_chunk_index(expr_->left_field_id_) + : segment->type() == SegmentType::Growing + ? upper_div(segment_chunk_reader_.active_count_, + segment_chunk_reader_.SizePerChunk()) + : segment->num_chunk_data(left_field_); right_num_chunk_ = is_right_indexed_ - ? segment_->num_chunk_index(expr_->right_field_id_) - : segment_->type() == SegmentType::Growing - ? upper_div(active_count_, size_per_chunk_) - : segment_->num_chunk_data(right_field_); + ? segment->num_chunk_index(expr_->right_field_id_) + : segment->type() == SegmentType::Growing + ? upper_div(segment_chunk_reader_.active_count_, + segment_chunk_reader_.SizePerChunk()) + : segment->num_chunk_data(right_field_); num_chunk_ = left_num_chunk_; } else { num_chunk_ = is_left_indexed_ - ? segment_->num_chunk_index(expr_->left_field_id_) - : upper_div(active_count_, size_per_chunk_); + ? segment->num_chunk_index(expr_->left_field_id_) + : upper_div(segment_chunk_reader_.active_count_, + segment_chunk_reader_.SizePerChunk()); } AssertInfo( @@ -151,128 +138,60 @@ class PhyCompareFilterExpr : public Expr { void MoveCursor() override { - if (segment_->is_chunked()) { - MoveCursorForMultipleChunk(); + if (segment_chunk_reader_.segment_->is_chunked()) { + segment_chunk_reader_.MoveCursorForMultipleChunk( + left_current_chunk_id_, + left_current_chunk_pos_, + left_field_, + left_num_chunk_, + batch_size_); + segment_chunk_reader_.MoveCursorForMultipleChunk( + right_current_chunk_id_, + right_current_chunk_pos_, + right_field_, + right_num_chunk_, + batch_size_); } else { - MoveCursorForSingleChunk(); - } - } - - void - MoveCursorForMultipleChunk() { - int64_t processed_rows = 0; - for (int64_t chunk_id = left_current_chunk_id_; - chunk_id < left_num_chunk_; - ++chunk_id) { - auto chunk_size = 0; - if (segment_->type() == SegmentType::Growing) { - chunk_size = chunk_id == left_num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; - } else { - chunk_size = segment_->chunk_size(left_field_, chunk_id); - } - - for (int i = chunk_id == left_current_chunk_id_ - ? left_current_chunk_pos_ - : 0; - i < chunk_size; - ++i) { - if (++processed_rows >= batch_size_) { - left_current_chunk_id_ = chunk_id; - left_current_chunk_pos_ = i + 1; - } - } - } - processed_rows = 0; - for (int64_t chunk_id = right_current_chunk_id_; - chunk_id < right_num_chunk_; - ++chunk_id) { - auto chunk_size = 0; - if (segment_->type() == SegmentType::Growing) { - chunk_size = chunk_id == right_num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; - } else { - chunk_size = segment_->chunk_size(right_field_, chunk_id); - } - - for (int i = chunk_id == right_current_chunk_id_ - ? right_current_chunk_pos_ - : 0; - i < chunk_size; - ++i) { - if (++processed_rows >= batch_size_) { - right_current_chunk_id_ = chunk_id; - right_current_chunk_pos_ = i + 1; - } - } - } - } - - void - MoveCursorForSingleChunk() { - int64_t processed_rows = 0; - for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; - ++chunk_id) { - auto chunk_size = chunk_id == num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; - - for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; - i < chunk_size; - ++i) { - if (++processed_rows >= batch_size_) { - current_chunk_id_ = chunk_id; - current_chunk_pos_ = i + 1; - } - } - } - } - - int64_t - GetCurrentRows() { - if (segment_->is_chunked()) { - auto current_rows = - is_left_indexed_ && segment_->type() == SegmentType::Sealed - ? left_current_chunk_pos_ - : segment_->num_rows_until_chunk(left_field_, - left_current_chunk_id_) + - left_current_chunk_pos_; - return current_rows; - } else { - return segment_->type() == SegmentType::Growing - ? current_chunk_id_ * size_per_chunk_ + - current_chunk_pos_ - : current_chunk_pos_; + segment_chunk_reader_.MoveCursorForSingleChunk( + current_chunk_id_, current_chunk_pos_, num_chunk_, batch_size_); } } private: + int64_t + GetCurrentRows() { + if (segment_chunk_reader_.segment_->is_chunked()) { + auto current_rows = + is_left_indexed_ && segment_chunk_reader_.segment_->type() == + SegmentType::Sealed + ? left_current_chunk_pos_ + : segment_chunk_reader_.segment_->num_rows_until_chunk( + left_field_, left_current_chunk_id_) + + left_current_chunk_pos_; + return current_rows; + } else { + return segment_chunk_reader_.segment_->type() == + SegmentType::Growing + ? current_chunk_id_ * + segment_chunk_reader_.SizePerChunk() + + current_chunk_pos_ + : current_chunk_pos_; + } + } + int64_t GetNextBatchSize(); bool IsStringExpr(); - template - MultipleChunkDataAccessor - GetChunkData(FieldId field_id, - bool index, - int64_t& current_chunk_id, - int64_t& current_chunk_pos); - - template - ChunkDataAccessor - GetChunkData(FieldId field_id, int chunk_id, int data_barrier); - template int64_t ProcessBothDataChunks(FUNC func, TargetBitmapView res, TargetBitmapView valid_res, ValTypes... values) { - if (segment_->is_chunked()) { + if (segment_chunk_reader_.segment_->is_chunked()) { return ProcessBothDataChunksForMultipleChunkchunk_data(left_field_, i); - auto right_chunk = segment_->chunk_data(right_field_, i); + auto left_chunk = + segment_chunk_reader_.segment_->chunk_data(left_field_, i); + auto right_chunk = + segment_chunk_reader_.segment_->chunk_data(right_field_, i); auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0; auto size = (i == (num_chunk_ - 1)) - ? (segment_->type() == SegmentType::Growing - ? (active_count_ % size_per_chunk_ == 0 - ? size_per_chunk_ - data_pos - : active_count_ % size_per_chunk_ - data_pos) - : active_count_ - data_pos) - : size_per_chunk_ - data_pos; + ? (segment_chunk_reader_.segment_->type() == + SegmentType::Growing + ? (active_count % segment_chunk_reader_ + .SizePerChunk() == + 0 + ? segment_chunk_reader_.SizePerChunk() - + data_pos + : active_count % segment_chunk_reader_ + .SizePerChunk() - + data_pos) + : active_count - data_pos) + : segment_chunk_reader_.SizePerChunk() - data_pos; if (processed_size + size >= batch_size_) { size = batch_size_ - processed_size; @@ -348,19 +276,29 @@ class PhyCompareFilterExpr : public Expr { // only call this function when left and right are not indexed, so they have the same number of chunks for (size_t i = left_current_chunk_id_; i < left_num_chunk_; i++) { - auto left_chunk = segment_->chunk_data(left_field_, i); - auto right_chunk = segment_->chunk_data(right_field_, i); + auto left_chunk = + segment_chunk_reader_.segment_->chunk_data(left_field_, i); + auto right_chunk = + segment_chunk_reader_.segment_->chunk_data(right_field_, i); auto data_pos = (i == left_current_chunk_id_) ? left_current_chunk_pos_ : 0; auto size = 0; - if (segment_->type() == SegmentType::Growing) { - size = (i == (left_num_chunk_ - 1)) - ? (active_count_ % size_per_chunk_ == 0 - ? size_per_chunk_ - data_pos - : active_count_ % size_per_chunk_ - data_pos) - : size_per_chunk_ - data_pos; + if (segment_chunk_reader_.segment_->type() == + SegmentType::Growing) { + size = + (i == (left_num_chunk_ - 1)) + ? (segment_chunk_reader_.active_count_ % + segment_chunk_reader_.SizePerChunk() == + 0 + ? segment_chunk_reader_.SizePerChunk() - data_pos + : segment_chunk_reader_.active_count_ % + segment_chunk_reader_.SizePerChunk() - + data_pos) + : segment_chunk_reader_.SizePerChunk() - data_pos; } else { - size = segment_->chunk_size(left_field_, i) - data_pos; + size = + segment_chunk_reader_.segment_->chunk_size(left_field_, i) - + data_pos; } if (processed_size + size >= batch_size_) { @@ -396,19 +334,6 @@ class PhyCompareFilterExpr : public Expr { return processed_size; } - MultipleChunkDataAccessor - GetChunkData(DataType data_type, - FieldId field_id, - bool index, - int64_t& current_chunk_id, - int64_t& current_chunk_pos); - - ChunkDataAccessor - GetChunkData(DataType data_type, - FieldId field_id, - int chunk_id, - int data_barrier); - template VectorPtr ExecCompareExprDispatcher(OpType op); @@ -432,7 +357,6 @@ class PhyCompareFilterExpr : public Expr { const FieldId right_field_; bool is_left_indexed_; bool is_right_indexed_; - int64_t active_count_{0}; int64_t num_chunk_{0}; int64_t left_num_chunk_{0}; int64_t right_num_chunk_{0}; @@ -442,9 +366,8 @@ class PhyCompareFilterExpr : public Expr { int64_t right_current_chunk_pos_{0}; int64_t current_chunk_id_{0}; int64_t current_chunk_pos_{0}; - int64_t size_per_chunk_{0}; - const segcore::SegmentInternalInterface* segment_; + const segcore::SegmentChunkReader segment_chunk_reader_; int64_t batch_size_; std::shared_ptr expr_; }; diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index 1332217f47..690c0e490d 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -16,9 +16,12 @@ #include "Expr.h" +#include "common/EasyAssert.h" #include "exec/expression/AlwaysTrueExpr.h" #include "exec/expression/BinaryArithOpEvalRangeExpr.h" #include "exec/expression/BinaryRangeExpr.h" +#include "exec/expression/CallExpr.h" +#include "exec/expression/ColumnExpr.h" #include "exec/expression/CompareExpr.h" #include "exec/expression/ConjunctExpr.h" #include "exec/expression/ExistsExpr.h" @@ -27,6 +30,10 @@ #include "exec/expression/LogicalUnaryExpr.h" #include "exec/expression/TermExpr.h" #include "exec/expression/UnaryExpr.h" +#include "exec/expression/ValueExpr.h" + +#include + namespace milvus { namespace exec { @@ -156,8 +163,14 @@ CompileExpression(const expr::TypedExprPtr& expr, }; auto input_types = GetTypes(compiled_inputs); - if (auto call = dynamic_cast(expr.get())) { - // TODO: support function register and search mode + if (auto call = std::dynamic_pointer_cast(expr)) { + result = std::make_shared( + compiled_inputs, + call, + "PhyCallExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::UnaryRangeFilterExpr>(expr)) { result = std::make_shared( @@ -251,6 +264,29 @@ CompileExpression(const expr::TypedExprPtr& expr, context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); + } else if (auto value_expr = + std::dynamic_pointer_cast( + expr)) { + // used for function call arguments, may emit any type + result = std::make_shared( + compiled_inputs, + value_expr, + "PhyValueExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); + } else if (auto column_expr = + std::dynamic_pointer_cast( + expr)) { + result = std::make_shared( + compiled_inputs, + column_expr, + "PhyColumnExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); + } else { + PanicInfo(ExprInvalid, "unsupport expr: ", expr->ToString()); } return result; } @@ -261,4 +297,4 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 12a48af424..73a772f3e4 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -77,6 +77,7 @@ class Expr { DataType type_; const std::vector> inputs_; std::string name_; + // NOTE: unused std::shared_ptr vector_func_; }; @@ -84,6 +85,9 @@ using ExprPtr = std::shared_ptr; using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int); +/* + * The expr has only one column. + */ class SegmentExpr : public Expr { public: SegmentExpr(const std::vector&& input, @@ -762,7 +766,8 @@ CompileExpression(const expr::TypedExprPtr& expr, class ExprSet { public: explicit ExprSet(const std::vector& logical_exprs, - ExecContext* exec_ctx) { + ExecContext* exec_ctx) + : exec_ctx_(exec_ctx) { exprs_ = CompileExpressions(logical_exprs, exec_ctx); } diff --git a/internal/core/src/exec/expression/ValueExpr.cpp b/internal/core/src/exec/expression/ValueExpr.cpp new file mode 100644 index 0000000000..80330f7f15 --- /dev/null +++ b/internal/core/src/exec/expression/ValueExpr.cpp @@ -0,0 +1,101 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ValueExpr.h" +#include "common/Vector.h" + +namespace milvus { +namespace exec { + +void +PhyValueExpr::Eval(EvalCtx& context, VectorPtr& result) { + int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_ + ? active_count_ - current_pos_ + : batch_size_; + + if (real_batch_size == 0) { + result = nullptr; + return; + } + + switch (expr_->type()) { + case DataType::NONE: + // null + result = std::make_shared>( + expr_->type(), real_batch_size, false, 1); + break; + case DataType::BOOL: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().bool_val()); + break; + case DataType::INT8: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT16: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT32: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT64: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::FLOAT: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().float_val()); + break; + case DataType::DOUBLE: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().float_val()); + break; + case DataType::STRING: + case DataType::VARCHAR: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().string_val()); + break; + // TODO: json and array type + case DataType::ARRAY: + case DataType::JSON: + default: + PanicInfo(DataTypeInvalid, + "PhyValueExpr not support data type " + + GetDataTypeName(expr_->type())); + } + current_pos_ += real_batch_size; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ValueExpr.h b/internal/core/src/exec/expression/ValueExpr.h new file mode 100644 index 0000000000..044f46ac39 --- /dev/null +++ b/internal/core/src/exec/expression/ValueExpr.h @@ -0,0 +1,67 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyValueExpr : public Expr { + public: + PhyValueExpr(const std::vector>& input, + const std::shared_ptr expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(expr->type(), std::move(input), name), + expr_(expr), + active_count_(active_count), + batch_size_(batch_size) { + AssertInfo(input.empty(), + "PhyValueExpr should not have input, but got " + + std::to_string(input.size())); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + void + MoveCursor() override { + int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_ + ? active_count_ - current_pos_ + : batch_size_; + + current_pos_ += real_batch_size; + } + + private: + std::shared_ptr expr_; + const int64_t active_count_; + int64_t current_pos_{0}; + const int64_t batch_size_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/FunctionFactory.cpp b/internal/core/src/exec/expression/function/FunctionFactory.cpp new file mode 100644 index 0000000000..4f621f6506 --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionFactory.cpp @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/FunctionFactory.h" +#include +#include "exec/expression/function/impl/StringFunctions.h" +#include "log/Log.h" + +namespace milvus { +namespace exec { +namespace expression { + +std::string +FilterFunctionRegisterKey::ToString() const { + std::ostringstream oss; + oss << func_name << "("; + for (size_t i = 0; i < func_param_type_list.size(); ++i) { + oss << GetDataTypeName(func_param_type_list[i]); + if (i < func_param_type_list.size() - 1) { + oss << ", "; + } + } + + oss << ")"; + return oss.str(); +} + +FunctionFactory& +FunctionFactory::Instance() { + static FunctionFactory factory; + return factory; +} + +void +FunctionFactory::Initialize() { + std::call_once(init_flag_, &FunctionFactory::RegisterAllFunctions, this); +} + +void +FunctionFactory::RegisterAllFunctions() { + RegisterFilterFunction( + "empty", {DataType::VARCHAR}, function::EmptyVarchar); + RegisterFilterFunction("starts_with", + {DataType::VARCHAR, DataType::VARCHAR}, + function::StartsWithVarchar); + LOG_INFO("{} functions registered", GetFilterFunctionNum()); +} + +void +FunctionFactory::RegisterFilterFunction( + std::string func_name, + std::vector func_param_type_list, + FilterFunctionPtr func) { + filter_function_map_[FilterFunctionRegisterKey{ + func_name, func_param_type_list}] = func; +} + +const FilterFunctionPtr +FunctionFactory::GetFilterFunction( + const FilterFunctionRegisterKey& func_sig) const { + auto iter = filter_function_map_.find(func_sig); + if (iter != filter_function_map_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/FunctionFactory.h b/internal/core/src/exec/expression/function/FunctionFactory.h new file mode 100644 index 0000000000..0563408c9d --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionFactory.h @@ -0,0 +1,111 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "common/Vector.h" + +namespace milvus { +namespace exec { + +class EvalCtx; +class Expr; +class PhyCallExpr; + +namespace expression { + +struct FilterFunctionRegisterKey { + std::string func_name; + std::vector func_param_type_list; + + std::string + ToString() const; + + bool + operator==(const FilterFunctionRegisterKey& other) const { + return func_name == other.func_name && + func_param_type_list == other.func_param_type_list; + } + + struct Hash { + size_t + operator()(const FilterFunctionRegisterKey& s) const { + size_t h1 = std::hash{}(s.func_name); + size_t h2 = boost::hash_range(s.func_param_type_list.begin(), + s.func_param_type_list.end()); + return h1 ^ h2; + } + }; +}; + +using FilterFunctionParameter = std::shared_ptr; +using FilterFunctionReturn = VectorPtr; +using FilterFunctionPtr = void (*)(const RowVector& args, + FilterFunctionReturn& result); + +class FunctionFactory { + public: + static FunctionFactory& + Instance(); + + void + Initialize(); + + void + RegisterFilterFunction(std::string func_name, + std::vector func_param_type_list, + FilterFunctionPtr func); + + const FilterFunctionPtr + GetFilterFunction(const FilterFunctionRegisterKey& func_sig) const; + + size_t + GetFilterFunctionNum() const { + return filter_function_map_.size(); + } + + std::vector + ListAllFilterFunctions() const { + std::vector result; + for (const auto& [key, value] : filter_function_map_) { + result.push_back(key); + } + return result; + } + + private: + void + RegisterAllFunctions(); + + std::unordered_map + filter_function_map_; + std::once_flag init_flag_; +}; + +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/FunctionImplUtils.cpp b/internal/core/src/exec/expression/function/FunctionImplUtils.cpp new file mode 100644 index 0000000000..cbfa22957d --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionImplUtils.cpp @@ -0,0 +1,30 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "exec/expression/function/FunctionImplUtils.h" +#include "common/EasyAssert.h" + +namespace milvus::exec::expression::function { + +void +CheckVarcharOrStringType(std::shared_ptr& vec) { + if (vec->type() != DataType::VARCHAR && vec->type() != DataType::STRING) { + PanicInfo(ExprInvalid, + "invalid argument type, expect VARCHAR or STRING, actual {}", + vec->type()); + } +} + +} // namespace milvus::exec::expression::function diff --git a/internal/core/src/exec/expression/function/FunctionImplUtils.h b/internal/core/src/exec/expression/function/FunctionImplUtils.h new file mode 100644 index 0000000000..2def96e82d --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionImplUtils.h @@ -0,0 +1,25 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "common/Vector.h" + +namespace milvus::exec::expression::function { + +void +CheckVarcharOrStringType(std::shared_ptr& vec); + +} // namespace milvus::exec::expression::function diff --git a/internal/core/src/exec/expression/function/impl/Empty.cpp b/internal/core/src/exec/expression/function/impl/Empty.cpp new file mode 100644 index 0000000000..02ceabf7b8 --- /dev/null +++ b/internal/core/src/exec/expression/function/impl/Empty.cpp @@ -0,0 +1,59 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/FunctionImplUtils.h" +#include "exec/expression/function/impl/StringFunctions.h" + +#include +#include +#include "common/EasyAssert.h" +#include "exec/expression/function/FunctionFactory.h" + +namespace milvus { +namespace exec { +namespace expression { +namespace function { + +void +EmptyVarchar(const RowVector& args, FilterFunctionReturn& result) { + if (args.childrens().size() != 1) { + PanicInfo(ExprInvalid, + "invalid argument count, expect 1, actual {}", + args.childrens().size()); + } + auto arg = args.child(0); + auto vec = std::dynamic_pointer_cast(arg); + Assert(vec != nullptr); + CheckVarcharOrStringType(vec); + TargetBitmap bitmap(vec->size(), false); + TargetBitmap valid_bitmap(vec->size(), true); + for (size_t i = 0; i < vec->size(); ++i) { + if (vec->ValidAt(i)) { + bitmap[i] = reinterpret_cast( + vec->RawValueAt(i, sizeof(std::string))) + ->empty(); + } else { + valid_bitmap[i] = false; + } + } + result = std::make_shared(std::move(bitmap), + std::move(valid_bitmap)); +} + +} // namespace function +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/impl/StartsWith.cpp b/internal/core/src/exec/expression/function/impl/StartsWith.cpp new file mode 100644 index 0000000000..c3a27a778b --- /dev/null +++ b/internal/core/src/exec/expression/function/impl/StartsWith.cpp @@ -0,0 +1,64 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/FunctionImplUtils.h" +#include "exec/expression/function/impl/StringFunctions.h" + +#include +#include +#include "common/EasyAssert.h" +#include "exec/expression/function/FunctionFactory.h" + +namespace milvus { +namespace exec { +namespace expression { +namespace function { + +void +StartsWithVarchar(const RowVector& args, FilterFunctionReturn& result) { + if (args.childrens().size() != 2) { + PanicInfo(ExprInvalid, + "invalid argument count, expect 2, actual {}", + args.childrens().size()); + } + auto strs = std::dynamic_pointer_cast(args.child(0)); + Assert(strs != nullptr); + CheckVarcharOrStringType(strs); + auto prefixes = std::dynamic_pointer_cast(args.child(1)); + Assert(prefixes != nullptr); + CheckVarcharOrStringType(prefixes); + + TargetBitmap bitmap(strs->size(), false); + TargetBitmap valid_bitmap(strs->size(), true); + for (size_t i = 0; i < strs->size(); ++i) { + if (strs->ValidAt(i) && prefixes->ValidAt(i)) { + auto* str_ptr = reinterpret_cast( + strs->RawValueAt(i, sizeof(std::string))); + auto* prefix_ptr = reinterpret_cast( + prefixes->RawValueAt(i, sizeof(std::string))); + bitmap.set(i, str_ptr->find(*prefix_ptr) == 0); + } else { + valid_bitmap[i] = false; + } + } + result = std::make_shared(std::move(bitmap), + std::move(valid_bitmap)); +} + +} // namespace function +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/impl/StringFunctions.h b/internal/core/src/exec/expression/function/impl/StringFunctions.h new file mode 100644 index 0000000000..b8d17bda99 --- /dev/null +++ b/internal/core/src/exec/expression/function/impl/StringFunctions.h @@ -0,0 +1,35 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "common/Vector.h" +#include "exec/expression/function/FunctionFactory.h" + +namespace milvus { +namespace exec { +namespace expression { +namespace function { + +void +EmptyVarchar(const RowVector& args, FilterFunctionReturn& result); + +void +StartsWithVarchar(const RowVector& args, FilterFunctionReturn& result); + +} // namespace function +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/init_c.cpp b/internal/core/src/exec/expression/function/init_c.cpp new file mode 100644 index 0000000000..072bd866a5 --- /dev/null +++ b/internal/core/src/exec/expression/function/init_c.cpp @@ -0,0 +1,23 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/init_c.h" +#include "exec/expression/function/FunctionFactory.h" + +void +InitExecExpressionFunctionFactory() { + milvus::exec::expression::FunctionFactory::Instance().Initialize(); +} diff --git a/internal/core/src/exec/expression/function/init_c.h b/internal/core/src/exec/expression/function/init_c.h new file mode 100644 index 0000000000..c7dbd3867f --- /dev/null +++ b/internal/core/src/exec/expression/function/init_c.h @@ -0,0 +1,28 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void +InitExecExpressionFunctionFactory(); + +#ifdef __cplusplus +}; +#endif diff --git a/internal/core/src/exec/operator/FilterBitsNode.cpp b/internal/core/src/exec/operator/FilterBitsNode.cpp index f7716a3fa1..3bf6d03968 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.cpp +++ b/internal/core/src/exec/operator/FilterBitsNode.cpp @@ -76,13 +76,24 @@ PhyFilterBitsNode::GetOutput() { "PhyFilterBitsNode result size should be size one and not " "be nullptr"); - auto col_vec = std::dynamic_pointer_cast(results_[0]); - auto col_vec_size = col_vec->size(); - TargetBitmapView view(col_vec->GetRawData(), col_vec_size); - bitset.append(view); - TargetBitmapView valid_view(col_vec->GetValidRawData(), col_vec_size); - valid_bitset.append(valid_view); - num_processed_rows_ += col_vec_size; + if (auto col_vec = + std::dynamic_pointer_cast(results_[0])) { + if (col_vec->IsBitmap()) { + auto col_vec_size = col_vec->size(); + TargetBitmapView view(col_vec->GetRawData(), col_vec_size); + bitset.append(view); + TargetBitmapView valid_view(col_vec->GetValidRawData(), + col_vec_size); + valid_bitset.append(valid_view); + num_processed_rows_ += col_vec_size; + } else { + PanicInfo(ExprInvalid, + "PhyFilterBitsNode result should be bitmap"); + } + } else { + PanicInfo(ExprInvalid, + "PhyFilterBitsNode result should be ColumnVector"); + } } bitset.flip(); Assert(bitset.size() == need_process_rows_); @@ -102,4 +113,4 @@ PhyFilterBitsNode::GetOutput() { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index f41b76d1a2..320e616b4e 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -21,6 +21,7 @@ #include #include +#include "exec/expression/function/FunctionFactory.h" #include "common/Exception.h" #include "common/Schema.h" #include "common/Types.h" @@ -211,6 +212,7 @@ class ITypeExpr { using TypedExprPtr = std::shared_ptr; +// NOTE: unused class InputTypeExpr : public ITypeExpr { public: InputTypeExpr(DataType type) : ITypeExpr(type) { @@ -224,42 +226,7 @@ class InputTypeExpr : public ITypeExpr { using InputTypeExprPtr = std::shared_ptr; -class CallTypeExpr : public ITypeExpr { - public: - CallTypeExpr(DataType type, - const std::vector& inputs, - std::string fun_name) - : ITypeExpr{type, std::move(inputs)} { - } - - virtual ~CallTypeExpr() = default; - - virtual const std::string& - name() const { - return name_; - } - - std::string - ToString() const override { - std::string str{}; - str += name(); - str += "("; - for (size_t i = 0; i < inputs_.size(); ++i) { - if (i != 0) { - str += ","; - } - str += inputs_[i]->ToString(); - } - str += ")"; - return str; - } - - private: - std::string name_; -}; - -using CallTypeExprPtr = std::shared_ptr; - +// NOTE: unused class FieldAccessTypeExpr : public ITypeExpr { public: FieldAccessTypeExpr(DataType type, const std::string& name) @@ -311,6 +278,71 @@ class ITypeFilterExpr : public ITypeExpr { virtual ~ITypeFilterExpr() = default; }; +class ColumnExpr : public ITypeExpr { + public: + explicit ColumnExpr(const ColumnInfo& column) + : ITypeExpr(column.data_type_), column_(column) { + } + + const ColumnInfo& + GetColumn() const { + return column_; + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "ColumnExpr: {columnInfo:" << column_.ToString() << "}"; + return ss.str(); + } + + private: + const ColumnInfo column_; +}; + +class ValueExpr : public ITypeExpr { + public: + explicit ValueExpr(const proto::plan::GenericValue& val) + : ITypeExpr(DataType::NONE), val_(val) { + switch (val.val_case()) { + case proto::plan::GenericValue::ValCase::kBoolVal: + type_ = DataType::BOOL; + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + type_ = DataType::INT64; + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + type_ = DataType::FLOAT; + break; + case proto::plan::GenericValue::ValCase::kStringVal: + type_ = DataType::VARCHAR; + break; + case proto::plan::GenericValue::ValCase::kArrayVal: + type_ = DataType::ARRAY; + break; + case proto::plan::GenericValue::ValCase::VAL_NOT_SET: + type_ = DataType::NONE; + break; + } + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "ValueExpr: {" + << " val:" << val_.DebugString() << "}"; + return ss.str(); + } + + const proto::plan::GenericValue + GetGenericValue() const { + return val_; + } + + private: + const proto::plan::GenericValue val_; +}; + class UnaryRangeFilterExpr : public ITypeFilterExpr { public: explicit UnaryRangeFilterExpr(const ColumnInfo& column, @@ -595,6 +627,46 @@ class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr { const proto::plan::GenericValue value_; }; +class CallExpr : public ITypeFilterExpr { + public: + CallExpr(const std::string fun_name, + const std::vector& parameters, + const exec::expression::FilterFunctionPtr function_ptr) + : fun_name_(std::move(fun_name)), function_ptr_(function_ptr) { + inputs_.insert(inputs_.end(), parameters.begin(), parameters.end()); + } + + virtual ~CallExpr() = default; + + const std::string& + fun_name() const { + return fun_name_; + } + + const exec::expression::FilterFunctionPtr + function_ptr() const { + return function_ptr_; + } + + std::string + ToString() const override { + std::string parameters; + for (auto& e : inputs_) { + parameters += e->ToString(); + parameters += ", "; + } + return fmt::format("CallExpr:[Function Name: {}, Parameters: {}]", + fun_name_, + parameters); + } + + private: + const std::string fun_name_; + const exec::expression::FilterFunctionPtr function_ptr_; +}; + +using CallExprPtr = std::shared_ptr; + class CompareExpr : public ITypeFilterExpr { public: CompareExpr(const FieldId& left_field, diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index d61ad31ce9..b3ddb01dc0 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -15,9 +15,11 @@ #include #include +#include #include "common/VectorTrait.h" #include "common/EasyAssert.h" +#include "exec/expression/function/FunctionFactory.h" #include "pb/plan.pb.h" #include "query/Utils.h" #include "knowhere/comp/materialized_view.h" @@ -256,6 +258,29 @@ ProtoParser::ParseBinaryRangeExprs( expr_pb.upper_inclusive()); } +expr::TypedExprPtr +ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) { + std::vector parameters; + std::vector func_param_type_list; + for (auto& param_expr : expr_pb.function_parameters()) { + // function parameter can be any type + auto e = this->ParseExprs(param_expr, TypeIsAny); + parameters.push_back(e); + func_param_type_list.push_back(e->type()); + } + auto& factory = exec::expression::FunctionFactory::Instance(); + exec::expression::FilterFunctionRegisterKey func_sig{ + expr_pb.function_name(), std::move(func_param_type_list)}; + + auto function = factory.GetFilterFunction(func_sig); + if (function == nullptr) { + PanicInfo(ExprInvalid, + "function " + func_sig.ToString() + " not found. "); + } + return std::make_shared( + expr_pb.function_name(), parameters, function); +} + expr::TypedExprPtr ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { auto& left_column_info = expr_pb.left_column_info(); @@ -349,45 +374,80 @@ ProtoParser::ParseJsonContainsExprs( std::move(values)); } +expr::TypedExprPtr +ProtoParser::ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb) { + return std::make_shared(expr_pb.info()); +} + +expr::TypedExprPtr +ProtoParser::ParseValueExprs(const proto::plan::ValueExpr& expr_pb) { + return std::make_shared(expr_pb.value()); +} + expr::TypedExprPtr ProtoParser::CreateAlwaysTrueExprs() { return std::make_shared(); } expr::TypedExprPtr -ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { +ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb, + TypeCheckFunction type_check) { using ppe = proto::plan::Expr; + expr::TypedExprPtr result; switch (expr_pb.expr_case()) { case ppe::kUnaryRangeExpr: { - return ParseUnaryRangeExprs(expr_pb.unary_range_expr()); + result = ParseUnaryRangeExprs(expr_pb.unary_range_expr()); + break; } case ppe::kBinaryExpr: { - return ParseBinaryExprs(expr_pb.binary_expr()); + result = ParseBinaryExprs(expr_pb.binary_expr()); + break; } case ppe::kUnaryExpr: { - return ParseUnaryExprs(expr_pb.unary_expr()); + result = ParseUnaryExprs(expr_pb.unary_expr()); + break; } case ppe::kTermExpr: { - return ParseTermExprs(expr_pb.term_expr()); + result = ParseTermExprs(expr_pb.term_expr()); + break; } case ppe::kBinaryRangeExpr: { - return ParseBinaryRangeExprs(expr_pb.binary_range_expr()); + result = ParseBinaryRangeExprs(expr_pb.binary_range_expr()); + break; } case ppe::kCompareExpr: { - return ParseCompareExprs(expr_pb.compare_expr()); + result = ParseCompareExprs(expr_pb.compare_expr()); + break; } case ppe::kBinaryArithOpEvalRangeExpr: { - return ParseBinaryArithOpEvalRangeExprs( + result = ParseBinaryArithOpEvalRangeExprs( expr_pb.binary_arith_op_eval_range_expr()); + break; } case ppe::kExistsExpr: { - return ParseExistExprs(expr_pb.exists_expr()); + result = ParseExistExprs(expr_pb.exists_expr()); + break; } case ppe::kAlwaysTrueExpr: { - return CreateAlwaysTrueExprs(); + result = CreateAlwaysTrueExprs(); + break; } case ppe::kJsonContainsExpr: { - return ParseJsonContainsExprs(expr_pb.json_contains_expr()); + result = ParseJsonContainsExprs(expr_pb.json_contains_expr()); + break; + } + case ppe::kCallExpr: { + result = ParseCallExprs(expr_pb.call_expr()); + break; + } + // may emit various types + case ppe::kColumnExpr: { + result = ParseColumnExprs(expr_pb.column_expr()); + break; + } + case ppe::kValueExpr: { + result = ParseValueExprs(expr_pb.value_expr()); + break; } default: { std::string s; @@ -396,6 +456,11 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { std::string("unsupported expr proto node: ") + s); } } + if (type_check(result->type())) { + return result; + } + PanicInfo( + ExprInvalid, "expr type check failed, actual type: {}", result->type()); } } // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index 63673cefb9..28aaaaa0cb 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -23,6 +23,17 @@ namespace milvus::query { class ProtoParser { + public: + using TypeCheckFunction = std::function; + static bool + TypeIsBool(const DataType type) { + return type == DataType::BOOL; + } + static bool + TypeIsAny(const DataType) { + return true; + } + public: explicit ProtoParser(const Schema& schema) : schema(schema) { } @@ -40,10 +51,15 @@ class ProtoParser { CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto); expr::TypedExprPtr - ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); + ParseExprs(const proto::plan::Expr& expr_pb, + TypeCheckFunction type_check = TypeIsBool); + + private: + expr::TypedExprPtr + CreateAlwaysTrueExprs(); expr::TypedExprPtr - ParseExprs(const proto::plan::Expr& expr_pb); + ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb); expr::TypedExprPtr ParseBinaryArithOpEvalRangeExprs( @@ -52,18 +68,15 @@ class ProtoParser { expr::TypedExprPtr ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb); + expr::TypedExprPtr + ParseCallExprs(const proto::plan::CallExpr& expr_pb); + + expr::TypedExprPtr + ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb); + expr::TypedExprPtr ParseCompareExprs(const proto::plan::CompareExpr& expr_pb); - expr::TypedExprPtr - ParseTermExprs(const proto::plan::TermExpr& expr_pb); - - expr::TypedExprPtr - ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb); - - expr::TypedExprPtr - ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb); - expr::TypedExprPtr ParseExistExprs(const proto::plan::ExistsExpr& expr_pb); @@ -71,14 +84,23 @@ class ProtoParser { ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb); expr::TypedExprPtr - CreateAlwaysTrueExprs(); + ParseTermExprs(const proto::plan::TermExpr& expr_pb); + + expr::TypedExprPtr + ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb); + + expr::TypedExprPtr + ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); + + expr::TypedExprPtr + ParseValueExprs(const proto::plan::ValueExpr& expr_pb); private: const Schema& schema; }; } // namespace milvus::query -// + template <> struct fmt::formatter : formatter { diff --git a/internal/core/src/segcore/SegmentChunkReader.cpp b/internal/core/src/segcore/SegmentChunkReader.cpp new file mode 100644 index 0000000000..432f0832c8 --- /dev/null +++ b/internal/core/src/segcore/SegmentChunkReader.cpp @@ -0,0 +1,327 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "segcore/SegmentChunkReader.h" + +namespace milvus::segcore { +template +MultipleChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor(FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) const { + if (index) { + auto& indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, current_chunk_id)); + auto current_chunk_size = segment_->type() == SegmentType::Growing + ? SizePerChunk() + : active_count_; + + if (indexing.HasRawData()) { + return [&, current_chunk_size]() -> const data_access_type { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, + current_chunk_id)); + } + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); + }; + } + } + auto chunk_data = + segment_->chunk_data(field_id, current_chunk_id).data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id).valid_data(); + auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); + return [=, + ¤t_chunk_id, + ¤t_chunk_pos]() mutable -> const data_access_type { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = + segment_->chunk_data(field_id, current_chunk_id).data(); + chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } + return chunk_data[current_chunk_pos++]; + }; +} + +template <> +MultipleChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor( + FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) const { + if (index) { + auto& indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, + current_chunk_id)); + auto current_chunk_size = segment_->type() == SegmentType::Growing + ? SizePerChunk() + : active_count_; + + if (indexing.HasRawData()) { + return [&, current_chunk_size]() mutable -> const data_access_type { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + indexing = const_cast&>( + segment_->chunk_scalar_index( + field_id, current_chunk_id)); + } + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); + }; + } + } + if (segment_->type() == SegmentType::Growing && + !storage::MmapManager::GetInstance() + .GetMmapConfig() + .growing_enable_mmap) { + auto chunk_data = + segment_->chunk_data(field_id, current_chunk_id) + .data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); + auto current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + return [=, + ¤t_chunk_id, + ¤t_chunk_pos]() mutable -> const data_access_type { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = + segment_ + ->chunk_data(field_id, current_chunk_id) + .data(); + chunk_valid_data = + segment_ + ->chunk_data(field_id, current_chunk_id) + .valid_data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } + return chunk_data[current_chunk_pos++]; + }; + } else { + auto chunk_data = + segment_->chunk_view(field_id, current_chunk_id) + .first.data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); + auto current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + return [=, + ¤t_chunk_id, + ¤t_chunk_pos]() mutable -> const data_access_type { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = segment_ + ->chunk_view( + field_id, current_chunk_id) + .first.data(); + chunk_valid_data = segment_ + ->chunk_data( + field_id, current_chunk_id) + .valid_data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } + + return std::string(chunk_data[current_chunk_pos++]); + }; + } +} + +MultipleChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor(DataType data_type, + FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) const { + switch (data_type) { + case DataType::BOOL: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT8: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT16: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT32: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT64: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::FLOAT: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::DOUBLE: + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::VARCHAR: { + return GetChunkDataAccessor( + field_id, index, current_chunk_id, current_chunk_pos); + } + default: + PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); + } +} + +template +ChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor(FieldId field_id, + int chunk_id, + int data_barrier) const { + if (chunk_id >= data_barrier) { + auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const data_access_type { + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); + }; + } + } + auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const data_access_type { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; +} + +template <> +ChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor(FieldId field_id, + int chunk_id, + int data_barrier) const { + if (chunk_id >= data_barrier) { + auto& indexing = + segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const data_access_type { + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); + }; + } + } + if (segment_->type() == SegmentType::Growing && + !storage::MmapManager::GetInstance() + .GetMmapConfig() + .growing_enable_mmap) { + auto chunk_data = + segment_->chunk_data(field_id, chunk_id).data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const data_access_type { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; + } else { + auto chunk_info = + segment_->chunk_view(field_id, chunk_id); + return [chunk_data = std::move(chunk_info.first), + chunk_valid_data = std::move(chunk_info.second)]( + int i) -> const data_access_type { + if (i < chunk_valid_data.size() && !chunk_valid_data[i]) { + return std::nullopt; + } + return std::string(chunk_data[i]); + }; + } +} + +ChunkDataAccessor +SegmentChunkReader::GetChunkDataAccessor(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier) const { + switch (data_type) { + case DataType::BOOL: + return GetChunkDataAccessor(field_id, chunk_id, data_barrier); + case DataType::INT8: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::INT16: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::INT32: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::INT64: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::FLOAT: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::DOUBLE: + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + case DataType::VARCHAR: { + return GetChunkDataAccessor( + field_id, chunk_id, data_barrier); + } + default: + PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); + } +} + +} // namespace milvus::segcore diff --git a/internal/core/src/segcore/SegmentChunkReader.h b/internal/core/src/segcore/SegmentChunkReader.h new file mode 100644 index 0000000000..9c662e9318 --- /dev/null +++ b/internal/core/src/segcore/SegmentChunkReader.h @@ -0,0 +1,141 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "common/Types.h" +#include "segcore/SegmentInterface.h" + +namespace milvus::segcore { + +using data_access_type = std::optional>; + +using ChunkDataAccessor = std::function; +using MultipleChunkDataAccessor = std::function; + +class SegmentChunkReader { + public: + SegmentChunkReader(const segcore::SegmentInternalInterface* segment, + int64_t active_count) + : segment_(segment), + active_count_(active_count), + size_per_chunk_(segment->size_per_chunk()) { + } + + MultipleChunkDataAccessor + GetChunkDataAccessor(DataType data_type, + FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) const; + + ChunkDataAccessor + GetChunkDataAccessor(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier) const; + + void + MoveCursorForMultipleChunk(int64_t& current_chunk_id, + int64_t& current_chunk_pos, + const FieldId field_id, + const int64_t num_chunk, + const int64_t batch_size) const { + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id; chunk_id < num_chunk; + ++chunk_id) { + int64_t chunk_size = 0; + if (segment_->type() == SegmentType::Growing) { + const auto size_per_chunk = SizePerChunk(); + chunk_size = chunk_id == num_chunk - 1 + ? active_count_ - chunk_id * size_per_chunk + : size_per_chunk; + } else { + chunk_size = segment_->chunk_size(field_id, chunk_id); + } + + for (int64_t i = chunk_id == current_chunk_id ? current_chunk_pos + : 0; + i < chunk_size; + ++i) { + if (++processed_rows >= batch_size) { + current_chunk_id = chunk_id; + current_chunk_pos = i + 1; + } + } + } + } + + void + MoveCursorForSingleChunk(int64_t& current_chunk_id, + int64_t& current_chunk_pos, + const int64_t num_chunk, + const int64_t batch_size) const { + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id; chunk_id < num_chunk; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk - 1 + ? active_count_ - chunk_id * SizePerChunk() + : SizePerChunk(); + + for (int64_t i = chunk_id == current_chunk_id ? current_chunk_pos + : 0; + i < chunk_size; + ++i) { + if (++processed_rows >= batch_size) { + current_chunk_id = chunk_id; + current_chunk_pos = i + 1; + } + } + } + } + + int64_t + SizePerChunk() const { + return size_per_chunk_; + } + + const int64_t active_count_; + const segcore::SegmentInternalInterface* segment_; + + private: + template + MultipleChunkDataAccessor + GetChunkDataAccessor(FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) const; + + template + ChunkDataAccessor + GetChunkDataAccessor(FieldId field_id, + int chunk_id, + int data_barrier) const; + + const int64_t size_per_chunk_; +}; + +} // namespace milvus::segcore diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index fe09f7c3af..fecb45fec6 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -149,9 +149,7 @@ class SegmentInternalInterface : public SegmentInterface { template std::pair, FixedVector> chunk_view(FieldId field_id, int64_t chunk_id) const { - auto chunk_info = chunk_view_impl(field_id, chunk_id); - auto string_views = chunk_info.first; - auto valid_data = chunk_info.second; + auto [string_views, valid_data] = chunk_view_impl(field_id, chunk_id); if constexpr (std::is_same_v) { return std::make_pair(std::move(string_views), std::move(valid_data)); diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 3fdd8cbf89..3b4109b6c6 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -48,6 +48,7 @@ set(MILVUS_TEST_FILES test_expr.cpp test_expr_materialized_view.cpp test_float16.cpp + test_function.cpp test_futures.cpp test_group_by.cpp test_growing.cpp diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index e26e911997..6e65c1f089 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -27,6 +27,7 @@ #include "exec/QueryContext.h" #include "expr/ITypeExpr.h" #include "exec/expression/Expr.h" +#include "exec/expression/function/FunctionFactory.h" using namespace milvus; using namespace milvus::exec; @@ -40,6 +41,10 @@ class TaskTest : public testing::TestWithParam { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + factory.Initialize(); + auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( "fakevec", GetParam(), 16, knowhere::metric::L2); @@ -113,6 +118,62 @@ INSTANTIATE_TEST_SUITE_P(TaskTestSuite, ::testing::Values(DataType::VECTOR_FLOAT, DataType::VECTOR_SPARSE_FLOAT)); +TEST_P(TaskTest, RegisterFunction) { + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + ASSERT_EQ(factory.GetFilterFunctionNum(), 2); + auto all_functions = factory.ListAllFilterFunctions(); + // for (auto& f : all_functions) { + // std::cout << f.toString() << std::endl; + // } + + auto func_ptr = factory.GetFilterFunction( + milvus::exec::expression::FilterFunctionRegisterKey{ + "empty", {DataType::VARCHAR}}); + ASSERT_TRUE(func_ptr != nullptr); +} + +TEST_P(TaskTest, CallExprEmpty) { + expr::ColumnInfo col(field_map_["string1"], DataType::VARCHAR); + std::vector parameters; + parameters.push_back(std::make_shared(col)); + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + auto empty_function_ptr = factory.GetFilterFunction( + milvus::exec::expression::FilterFunctionRegisterKey{ + "empty", {DataType::VARCHAR}}); + auto call_expr = std::make_shared( + "empty", parameters, empty_function_ptr); + ASSERT_EQ(call_expr->inputs().size(), 1); + std::vector sources; + auto filter_node = std::make_shared( + "plannode id 1", call_expr, sources); + auto plan = plan::PlanFragment(filter_node); + auto query_context = std::make_shared( + "test1", + segment_.get(), + 1000000, + MAX_TIMESTAMP, + std::make_shared( + std::unordered_map{})); + + auto start = std::chrono::steady_clock::now(); + auto task = Task::Create("task_call_expr_empty", plan, 0, query_context); + int64_t num_rows = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + num_rows += result->size(); + } + auto cost = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + std::cout << "cost: " << cost << "us" << std::endl; + EXPECT_EQ(num_rows, num_rows_); +} + TEST_P(TaskTest, UnaryExpr) { ::milvus::proto::plan::GenericValue value; value.set_int64_val(-1); @@ -355,4 +416,4 @@ TEST_P(TaskTest, CompileInputs_or_with_and) { "PhyUnaryRangeFilterExpr"); } } -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 51079941a5..1b08bca0c5 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include "index/IndexFactory.h" #include "exec/expression/Expr.h" #include "exec/Task.h" +#include "exec/expression/function/FunctionFactory.h" #include "expr/ITypeExpr.h" #include "index/BitmapIndex.h" #include "index/InvertedIndexTantivy.h" @@ -1739,6 +1741,175 @@ TEST_P(ExprTest, TestTermNullable) { } } +TEST_P(ExprTest, TestCall) { + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + factory.Initialize(); + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto varchar_fid = schema->AddDebugField("address", DataType::VARCHAR); + schema->set_primary_field_id(varchar_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector address_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_address_col = raw_data.get_col(varchar_fid); + address_col.insert( + address_col.end(), new_address_col.begin(), new_address_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + + std::tuple> test_cases[] = { + {R"(vector_anns: < + field_id: 100 + predicates: < + call_expr: < + function_name: "empty" + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)", + [](std::string& v) { return v.empty(); }}, + {R"(vector_anns: < + field_id: 100 + predicates: < + call_expr: < + function_name: "starts_with" + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)", + [](std::string&) { return true; }}}; + + for (auto& [raw_plan, ref_func] : test_cases) { + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + ASSERT_EQ(ans, ref_func(address_col[i])) + << "@" << i << "!!" << address_col[i]; + } + } + + std::string incorrect_test_cases[] = { + R"(vector_anns: < + field_id: 100 + predicates: < + call_expr: < + function_name: "empty" + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)", + R"(vector_anns: < + field_id: 100 + predicates: < + call_expr: < + function_name: "starts_with" + function_parameters: < + column_expr: < + info: < + field_id: 101 + data_type: VarChar + > + > + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"}; + for (auto& raw_plan : incorrect_test_cases) { + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + EXPECT_ANY_THROW( + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size())); + } +} + TEST_P(ExprTest, TestCompare) { std::vector>> testcases = { diff --git a/internal/core/unittest/test_function.cpp b/internal/core/unittest/test_function.cpp new file mode 100644 index 0000000000..3b98aa6dff --- /dev/null +++ b/internal/core/unittest/test_function.cpp @@ -0,0 +1,239 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/function/impl/StringFunctions.h" + +using namespace milvus; +using namespace milvus::exec::expression::function; + +class FunctionTest : public ::testing::Test { + protected: + void + SetUp() override { + } + + void + TearDown() override { + } +}; + +TEST_F(FunctionTest, Empty) { + std::vector arg_vec; + auto col1 = + std::make_shared(milvus::DataType::STRING, 15); + auto* col1_data = col1->RawAsValues(); + for (int i = 0; i <= 10; ++i) { + col1_data[i] = std::to_string(i); + } + for (int i = 11; i < 15; ++i) { + col1_data[i] = ""; + } + arg_vec.push_back(col1); + milvus::RowVector args(std::move(arg_vec)); + VectorPtr result; + EmptyVarchar(args, result); + + auto result_vec = std::dynamic_pointer_cast(result); + ASSERT_NE(result_vec, nullptr); + TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size()); + for (int i = 0; i < 15; ++i) { + EXPECT_TRUE(result_vec->ValidAt(i)) << "i: " << i; + EXPECT_EQ(bitmap[i], i >= 11) << "i: " << i; + } +} + +TEST_F(FunctionTest, EmptyNull) { + std::vector arg_vec; + auto col1 = std::make_shared( + milvus::DataType::STRING, 15, 15); + arg_vec.push_back(col1); + milvus::RowVector args(std::move(arg_vec)); + VectorPtr result; + EmptyVarchar(args, result); + + auto result_vec = std::dynamic_pointer_cast(result); + ASSERT_NE(result_vec, nullptr); + TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size()); + for (int i = 0; i < 15; ++i) { + EXPECT_FALSE(result_vec->ValidAt(i)) << "i: " << i; + } +} + +TEST_F(FunctionTest, EmptyConstant) { + std::vector arg_vec; + auto col1 = std::make_shared>( + milvus::DataType::STRING, 15, "xx"); + arg_vec.push_back(col1); + milvus::RowVector args(std::move(arg_vec)); + VectorPtr result; + EmptyVarchar(args, result); + + auto result_vec = std::dynamic_pointer_cast(result); + ASSERT_NE(result_vec, nullptr); + TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size()); + for (int i = 0; i < 15; ++i) { + EXPECT_TRUE(result_vec->ValidAt(i)) << "i: " << i; + EXPECT_FALSE(bitmap[i]) << "i: " << i; + } +} + +TEST_F(FunctionTest, EmptyIncorrectArgs) { + VectorPtr result; + + std::vector arg_vec; + + // empty args + milvus::RowVector empty_args(arg_vec); + EXPECT_ANY_THROW(EmptyVarchar(empty_args, result)); + + // incorrect type, expected string or varchar + arg_vec.push_back(std::make_shared( + milvus::DataType::INT32, 15, 15)); + milvus::RowVector int_args(arg_vec); + EXPECT_ANY_THROW(EmptyVarchar(int_args, result)); + + arg_vec.clear(); + + // incorrect size, expected 1 + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + milvus::RowVector multi_args(arg_vec); + EXPECT_ANY_THROW(EmptyVarchar(multi_args, result)); +} + +static constexpr int64_t STARTS_WITH_ROW_COUNT = 8; +static void +InitStrsForStartWith(std::shared_ptr col1) { + auto* col1_data = col1->RawAsValues(); + col1_data[0] = "123"; + col1_data[1] = ""; + col1_data[2] = ""; + TargetBitmapView valid_bitmap_col1(col1->GetValidRawData(), col1->size()); + valid_bitmap_col1[2] = false; + col1_data[3] = "aaabbbaaa"; + col1_data[4] = "aaabbbaaa"; + col1_data[5] = "xx"; + col1_data[6] = "xx"; + col1_data[7] = "1"; +} + +static void +StartWithCheck(VectorPtr result, bool valid[], bool expected[]) { + auto result_vec = std::dynamic_pointer_cast(result); + ASSERT_NE(result_vec, nullptr); + TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size()); + for (int i = 0; i < STARTS_WITH_ROW_COUNT; ++i) { + EXPECT_EQ(result_vec->ValidAt(i), valid[i]) << "i: " << i; + EXPECT_EQ(bitmap[i], expected[i]) << "i: " << i; + } +} + +TEST_F(FunctionTest, StartsWithColumnVector) { + std::vector arg_vec; + + auto col1 = std::make_shared(milvus::DataType::STRING, + STARTS_WITH_ROW_COUNT); + InitStrsForStartWith(col1); + arg_vec.push_back(col1); + + auto col2 = std::make_shared(milvus::DataType::STRING, + STARTS_WITH_ROW_COUNT); + auto* col2_data = col2->RawAsValues(); + col2_data[0] = "12"; + col2_data[1] = "1"; + col2_data[3] = "aaabbbaaac"; + col2_data[4] = "aaabbbaax"; + col2_data[5] = ""; + col2_data[6] = ""; + TargetBitmapView valid_bitmap_col2(col2->GetValidRawData(), col2->size()); + valid_bitmap_col2[6] = false; + col2_data[7] = "124"; + arg_vec.push_back(col2); + + milvus::RowVector args(std::move(arg_vec)); + + bool valid[STARTS_WITH_ROW_COUNT] = { + true, true, false, true, true, true, false, true}; + bool expected[STARTS_WITH_ROW_COUNT] = { + true, false, false, false, false, true, false, false}; + + VectorPtr result; + StartsWithVarchar(args, result); + StartWithCheck(result, valid, expected); +} + +TEST_F(FunctionTest, StartsWithColumnAndConstantVector) { + std::vector arg_vec; + + auto col1 = std::make_shared(milvus::DataType::STRING, + STARTS_WITH_ROW_COUNT); + InitStrsForStartWith(col1); + arg_vec.push_back(col1); + + const std::string constant_str = "1"; + auto col2 = std::make_shared>( + milvus::DataType::STRING, STARTS_WITH_ROW_COUNT, constant_str); + arg_vec.push_back(col2); + + milvus::RowVector args(std::move(arg_vec)); + + bool valid[STARTS_WITH_ROW_COUNT] = { + true, true, false, true, true, true, true, true}; + bool expected[STARTS_WITH_ROW_COUNT] = { + true, false, false, false, false, false, false, true}; + + VectorPtr result; + StartsWithVarchar(args, result); + StartWithCheck(result, valid, expected); +} + +TEST_F(FunctionTest, StartsWithIncorrectArgs) { + VectorPtr result; + + std::vector arg_vec; + + // empty args + milvus::RowVector empty_args(arg_vec); + EXPECT_ANY_THROW(StartsWithVarchar(empty_args, result)); + + // incorrect type, expected string or varchar + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + arg_vec.push_back(std::make_shared( + milvus::DataType::INT32, 15, 15)); + milvus::RowVector string_int_args(arg_vec); + EXPECT_ANY_THROW(StartsWithVarchar(string_int_args, result)); + + arg_vec.clear(); + + // incorrect size, expected 2 + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + milvus::RowVector single_args(arg_vec); + EXPECT_ANY_THROW(StartsWithVarchar(single_args, result)); + + arg_vec.clear(); + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + arg_vec.push_back(std::make_shared( + milvus::DataType::STRING, 15, 15)); + milvus::RowVector three_args(arg_vec); + EXPECT_ANY_THROW(StartsWithVarchar(three_args, result)); +} diff --git a/internal/parser/planparserv2/Plan.g4 b/internal/parser/planparserv2/Plan.g4 index c0644436a2..f28f471228 100644 --- a/internal/parser/planparserv2/Plan.g4 +++ b/internal/parser/planparserv2/Plan.g4 @@ -24,6 +24,7 @@ expr: | (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll | (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny | ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength + | Identifier '(' ( expr (',' expr )* ','? )? ')' # Call | expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range | expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange | expr op = (LT | LE | GT | GE) expr # Relational diff --git a/internal/parser/planparserv2/check_identical_test.go b/internal/parser/planparserv2/check_identical_test.go index 9f48aec504..321920e8c6 100644 --- a/internal/parser/planparserv2/check_identical_test.go +++ b/internal/parser/planparserv2/check_identical_test.go @@ -14,17 +14,27 @@ func TestCheckIdentical(t *testing.T) { helper, err := typeutil.CreateSchemaHelper(schema) assert.NoError(t, err) - exprStr1 := `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))` - exprStr2 := `Int32Field in [1, 2, 3]` + exprStr1Arr := []string{ + `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))`, + `f1()`, + } + exprStr2Arr := []string{ + `Int32Field in [1, 2, 3]`, + `f2(Int32Field, Int64Field)`, + } + for i := range exprStr1Arr { + exprStr1 := exprStr1Arr[i] + exprStr2 := exprStr2Arr[i] - expr1, err := ParseExpr(helper, exprStr1) - assert.NoError(t, err) - expr2, err := ParseExpr(helper, exprStr2) - assert.NoError(t, err) + expr1, err := ParseExpr(helper, exprStr1) + assert.NoError(t, err) + expr2, err := ParseExpr(helper, exprStr2) + assert.NoError(t, err) - assert.True(t, CheckPredicatesIdentical(expr1, expr1)) - assert.True(t, CheckPredicatesIdentical(expr2, expr2)) - assert.False(t, CheckPredicatesIdentical(expr1, expr2)) + assert.True(t, CheckPredicatesIdentical(expr1, expr1)) + assert.True(t, CheckPredicatesIdentical(expr2, expr2)) + assert.False(t, CheckPredicatesIdentical(expr1, expr2)) + } } func TestCheckQueryInfoIdentical(t *testing.T) { diff --git a/internal/parser/planparserv2/generated/Plan.interp b/internal/parser/planparserv2/generated/Plan.interp index 41ef66eeef..8cb8890c4f 100644 --- a/internal/parser/planparserv2/generated/Plan.interp +++ b/internal/parser/planparserv2/generated/Plan.interp @@ -101,4 +101,4 @@ expr atn: -[4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, 41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, 64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, 2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, 63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, 0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, 1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, 63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, 5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, 0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, 118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, 0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, 0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, 86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, 7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, 93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, 10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, 0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, 0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, 106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, 110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, 5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, 12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, 0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, 83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, 0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, 117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, 121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, 0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119] \ No newline at end of file +[4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134, 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0, 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5, 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65, 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73, 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0, 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5, 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79, 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0, 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0, 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43, 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83, 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5, 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89, 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3, 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0, 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3, 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4, 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0, 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108, 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134, 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3, 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0, 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0, 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5, 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127, 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131, 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1, 0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0, 133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109, 1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0, 0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0, 133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135, 136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68, 72, 74, 79, 95, 133, 135] \ No newline at end of file diff --git a/internal/parser/planparserv2/generated/plan_base_visitor.go b/internal/parser/planparserv2/generated/plan_base_visitor.go index e8ae619676..2e7a30e771 100644 --- a/internal/parser/planparserv2/generated/plan_base_visitor.go +++ b/internal/parser/planparserv2/generated/plan_base_visitor.go @@ -59,6 +59,10 @@ func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} { return v.VisitChildren(ctx) } +func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} { + return v.VisitChildren(ctx) +} + func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} { return v.VisitChildren(ctx) } diff --git a/internal/parser/planparserv2/generated/plan_parser.go b/internal/parser/planparserv2/generated/plan_parser.go index e5dc91fda2..8869e09e0a 100644 --- a/internal/parser/planparserv2/generated/plan_parser.go +++ b/internal/parser/planparserv2/generated/plan_parser.go @@ -50,65 +50,73 @@ func planParserInit() { } staticData.PredictionContextCache = antlr.NewPredictionContextCache() staticData.serializedATN = []int32{ - 4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, - 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, + 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0, + 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, - 1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, - 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, - 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, - 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, - 154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, - 41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, - 64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, - 0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, - 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, - 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, - 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, - 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, - 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, - 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, - 0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, - 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, - 0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, - 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, - 49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, - 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, - 64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, - 0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, - 2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, - 63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, - 0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, - 1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, - 63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, - 5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, - 0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, - 118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, - 0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, - 79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, - 0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, - 86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, - 7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, - 93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, - 10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, - 0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, - 0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, - 106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, - 110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, - 5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, - 12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, - 0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, - 83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, - 0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, - 117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, - 121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, - 0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119, + 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134, + 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, + 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, + 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, + 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3, + 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0, + 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1, + 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14, + 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, + 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, + 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, + 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, + 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, + 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, + 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37, + 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, + 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7, + 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, + 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0, + 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, + 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, + 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5, + 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65, + 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0, + 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73, + 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0, + 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5, + 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79, + 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0, + 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0, + 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43, + 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0, + 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83, + 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5, + 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89, + 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3, + 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0, + 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3, + 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4, + 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0, + 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108, + 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134, + 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3, + 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0, + 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0, + 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5, + 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127, + 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131, + 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1, + 0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0, + 133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109, + 1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0, + 0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0, + 133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135, + 136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68, + 72, 74, 79, 95, 133, 135, } deserializer := antlr.NewATNDeserializer(nil) staticData.atn = deserializer.Deserialize(staticData.serializedATN) @@ -981,6 +989,79 @@ func (s *ShiftContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { } } +type CallContext struct { + ExprContext +} + +func NewCallContext(parser antlr.Parser, ctx antlr.ParserRuleContext) *CallContext { + var p = new(CallContext) + + InitEmptyExprContext(&p.ExprContext) + p.parser = parser + p.CopyAll(ctx.(*ExprContext)) + + return p +} + +func (s *CallContext) GetRuleContext() antlr.RuleContext { + return s +} + +func (s *CallContext) Identifier() antlr.TerminalNode { + return s.GetToken(PlanParserIdentifier, 0) +} + +func (s *CallContext) AllExpr() []IExprContext { + children := s.GetChildren() + len := 0 + for _, ctx := range children { + if _, ok := ctx.(IExprContext); ok { + len++ + } + } + + tst := make([]IExprContext, len) + i := 0 + for _, ctx := range children { + if t, ok := ctx.(IExprContext); ok { + tst[i] = t.(IExprContext) + i++ + } + } + + return tst +} + +func (s *CallContext) Expr(i int) IExprContext { + var t antlr.RuleContext + j := 0 + for _, ctx := range s.GetChildren() { + if _, ok := ctx.(IExprContext); ok { + if j == i { + t = ctx.(antlr.RuleContext) + break + } + j++ + } + } + + if t == nil { + return nil + } + + return t.(IExprContext) +} + +func (s *CallContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { + switch t := visitor.(type) { + case PlanVisitor: + return t.VisitCall(s) + + default: + return t.VisitChildren(s) + } +} + type ReverseRangeContext struct { ExprContext op1 antlr.Token @@ -2231,14 +2312,14 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { var _alt int p.EnterOuterAlt(localctx, 1) - p.SetState(63) + p.SetState(79) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - switch p.GetTokenStream().LA(1) { - case PlanParserIntegerConstant: + switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) { + case 1: localctx = NewIntegerContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2252,7 +2333,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserFloatingConstant: + case 2: localctx = NewFloatingContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2265,7 +2346,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserBooleanConstant: + case 3: localctx = NewBooleanContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2278,7 +2359,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserStringLiteral: + case 4: localctx = NewStringContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2291,7 +2372,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserIdentifier: + case 5: localctx = NewIdentifierContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2304,7 +2385,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONIdentifier: + case 6: localctx = NewJSONIdentifierContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2317,7 +2398,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserT__0: + case 7: localctx = NewParensContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2342,7 +2423,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserT__2: + case 8: localctx = NewArrayContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2420,7 +2501,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserEmptyArray: + case 9: localctx = NewEmptyArrayContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2433,7 +2514,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserTEXTMATCH: + case 10: localctx = NewTextMatchContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2486,7 +2567,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserADD, PlanParserSUB, PlanParserBNOT, PlanParserNOT: + case 11: localctx = NewUnaryContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2510,10 +2591,10 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { p.SetState(35) - p.expr(19) + p.expr(20) } - case PlanParserJSONContains, PlanParserArrayContains: + case 12: localctx = NewJSONContainsContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2561,7 +2642,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONContainsAll, PlanParserArrayContainsAll: + case 13: localctx = NewJSONContainsAllContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2609,7 +2690,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONContainsAny, PlanParserArrayContainsAny: + case 14: localctx = NewJSONContainsAnyContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2657,7 +2738,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserArrayLength: + case 15: localctx = NewArrayLengthContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2697,13 +2778,13 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserEXISTS: - localctx = NewExistsContext(p, localctx) + case 16: + localctx = NewCallContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { p.SetState(61) - p.Match(PlanParserEXISTS) + p.Match(PlanParserIdentifier) if p.HasError() { // Recognition error - abort rule goto errorExit @@ -2711,20 +2792,115 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { p.SetState(62) + p.Match(PlanParserT__0) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + p.SetState(74) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _la = p.GetTokenStream().LA(1) + + if (int64(_la) & ^0x3f) == 0 && ((int64(1)<<_la)&35183030034442) != 0 { + { + p.SetState(63) + p.expr(0) + } + p.SetState(68) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext()) + if p.HasError() { + goto errorExit + } + for _alt != 2 && _alt != antlr.ATNInvalidAltNumber { + if _alt == 1 { + { + p.SetState(64) + p.Match(PlanParserT__3) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + { + p.SetState(65) + p.expr(0) + } + + } + p.SetState(70) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext()) + if p.HasError() { + goto errorExit + } + } + p.SetState(72) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _la = p.GetTokenStream().LA(1) + + if _la == PlanParserT__3 { + { + p.SetState(71) + p.Match(PlanParserT__3) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + + } + + } + { + p.SetState(76) + p.Match(PlanParserT__1) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + + case 17: + localctx = NewExistsContext(p, localctx) + p.SetParserRuleContext(localctx) + _prevctx = localctx + { + p.SetState(77) + p.Match(PlanParserEXISTS) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + { + p.SetState(78) p.expr(1) } - default: - p.SetError(antlr.NewNoViableAltException(p, nil, nil, nil, nil, nil)) + case antlr.ATNInvalidAltNumber: goto errorExit } p.GetParserRuleContext().SetStop(p.GetTokenStream().LT(-1)) - p.SetState(119) + p.SetState(135) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext()) if p.HasError() { goto errorExit } @@ -2734,24 +2910,24 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.TriggerExitRuleEvent() } _prevctx = localctx - p.SetState(117) + p.SetState(133) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 4, p.GetParserRuleContext()) { + switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 7, p.GetParserRuleContext()) { case 1: localctx = NewPowerContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(65) + p.SetState(81) - if !(p.Precpred(p.GetParserRuleContext(), 20)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 20)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 21)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 21)", "")) goto errorExit } { - p.SetState(66) + p.SetState(82) p.Match(PlanParserPOW) if p.HasError() { // Recognition error - abort rule @@ -2759,21 +2935,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(67) - p.expr(21) + p.SetState(83) + p.expr(22) } case 2: localctx = NewMulDivModContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(68) + p.SetState(84) - if !(p.Precpred(p.GetParserRuleContext(), 18)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 19)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 19)", "")) goto errorExit } { - p.SetState(69) + p.SetState(85) var _lt = p.GetTokenStream().LT(1) @@ -2791,21 +2967,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(70) - p.expr(19) + p.SetState(86) + p.expr(20) } case 3: localctx = NewAddSubContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(71) + p.SetState(87) - if !(p.Precpred(p.GetParserRuleContext(), 17)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 18)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) goto errorExit } { - p.SetState(72) + p.SetState(88) var _lt = p.GetTokenStream().LT(1) @@ -2823,21 +2999,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(73) - p.expr(18) + p.SetState(89) + p.expr(19) } case 4: localctx = NewShiftContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(74) + p.SetState(90) - if !(p.Precpred(p.GetParserRuleContext(), 16)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 17)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) goto errorExit } { - p.SetState(75) + p.SetState(91) var _lt = p.GetTokenStream().LT(1) @@ -2855,20 +3031,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(76) - p.expr(17) + p.SetState(92) + p.expr(18) } case 5: localctx = NewTermContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(77) + p.SetState(93) - if !(p.Precpred(p.GetParserRuleContext(), 15)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 15)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 16)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) goto errorExit } - p.SetState(79) + p.SetState(95) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit @@ -2877,7 +3053,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { if _la == PlanParserNOT { { - p.SetState(78) + p.SetState(94) var _m = p.Match(PlanParserNOT) @@ -2890,7 +3066,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { - p.SetState(81) + p.SetState(97) p.Match(PlanParserIN) if p.HasError() { // Recognition error - abort rule @@ -2898,21 +3074,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(82) - p.expr(16) + p.SetState(98) + p.expr(17) } case 6: localctx = NewRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(83) + p.SetState(99) if !(p.Precpred(p.GetParserRuleContext(), 10)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 10)", "")) goto errorExit } { - p.SetState(84) + p.SetState(100) var _lt = p.GetTokenStream().LT(1) @@ -2930,7 +3106,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(85) + p.SetState(101) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -2941,7 +3117,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(86) + p.SetState(102) var _lt = p.GetTokenStream().LT(1) @@ -2959,21 +3135,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(87) + p.SetState(103) p.expr(11) } case 7: localctx = NewReverseRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(88) + p.SetState(104) if !(p.Precpred(p.GetParserRuleContext(), 9)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 9)", "")) goto errorExit } { - p.SetState(89) + p.SetState(105) var _lt = p.GetTokenStream().LT(1) @@ -2991,7 +3167,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(90) + p.SetState(106) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -3002,7 +3178,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(91) + p.SetState(107) var _lt = p.GetTokenStream().LT(1) @@ -3020,21 +3196,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(92) + p.SetState(108) p.expr(10) } case 8: localctx = NewRelationalContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(93) + p.SetState(109) if !(p.Precpred(p.GetParserRuleContext(), 8)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 8)", "")) goto errorExit } { - p.SetState(94) + p.SetState(110) var _lt = p.GetTokenStream().LT(1) @@ -3052,21 +3228,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(95) + p.SetState(111) p.expr(9) } case 9: localctx = NewEqualityContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(96) + p.SetState(112) if !(p.Precpred(p.GetParserRuleContext(), 7)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 7)", "")) goto errorExit } { - p.SetState(97) + p.SetState(113) var _lt = p.GetTokenStream().LT(1) @@ -3084,21 +3260,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(98) + p.SetState(114) p.expr(8) } case 10: localctx = NewBitAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(99) + p.SetState(115) if !(p.Precpred(p.GetParserRuleContext(), 6)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 6)", "")) goto errorExit } { - p.SetState(100) + p.SetState(116) p.Match(PlanParserBAND) if p.HasError() { // Recognition error - abort rule @@ -3106,21 +3282,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(101) + p.SetState(117) p.expr(7) } case 11: localctx = NewBitXorContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(102) + p.SetState(118) if !(p.Precpred(p.GetParserRuleContext(), 5)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 5)", "")) goto errorExit } { - p.SetState(103) + p.SetState(119) p.Match(PlanParserBXOR) if p.HasError() { // Recognition error - abort rule @@ -3128,21 +3304,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(104) + p.SetState(120) p.expr(6) } case 12: localctx = NewBitOrContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(105) + p.SetState(121) if !(p.Precpred(p.GetParserRuleContext(), 4)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 4)", "")) goto errorExit } { - p.SetState(106) + p.SetState(122) p.Match(PlanParserBOR) if p.HasError() { // Recognition error - abort rule @@ -3150,21 +3326,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(107) + p.SetState(123) p.expr(5) } case 13: localctx = NewLogicalAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(108) + p.SetState(124) if !(p.Precpred(p.GetParserRuleContext(), 3)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 3)", "")) goto errorExit } { - p.SetState(109) + p.SetState(125) p.Match(PlanParserAND) if p.HasError() { // Recognition error - abort rule @@ -3172,21 +3348,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(110) + p.SetState(126) p.expr(4) } case 14: localctx = NewLogicalOrContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(111) + p.SetState(127) if !(p.Precpred(p.GetParserRuleContext(), 2)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", "")) goto errorExit } { - p.SetState(112) + p.SetState(128) p.Match(PlanParserOR) if p.HasError() { // Recognition error - abort rule @@ -3194,21 +3370,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(113) + p.SetState(129) p.expr(3) } case 15: localctx = NewLikeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(114) + p.SetState(130) - if !(p.Precpred(p.GetParserRuleContext(), 22)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 22)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 23)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 23)", "")) goto errorExit } { - p.SetState(115) + p.SetState(131) p.Match(PlanParserLIKE) if p.HasError() { // Recognition error - abort rule @@ -3216,7 +3392,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(116) + p.SetState(132) p.Match(PlanParserStringLiteral) if p.HasError() { // Recognition error - abort rule @@ -3229,12 +3405,12 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - p.SetState(121) + p.SetState(137) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext()) if p.HasError() { goto errorExit } @@ -3270,19 +3446,19 @@ func (p *PlanParser) Sempred(localctx antlr.RuleContext, ruleIndex, predIndex in func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) bool { switch predIndex { case 0: - return p.Precpred(p.GetParserRuleContext(), 20) + return p.Precpred(p.GetParserRuleContext(), 21) case 1: - return p.Precpred(p.GetParserRuleContext(), 18) + return p.Precpred(p.GetParserRuleContext(), 19) case 2: - return p.Precpred(p.GetParserRuleContext(), 17) + return p.Precpred(p.GetParserRuleContext(), 18) case 3: - return p.Precpred(p.GetParserRuleContext(), 16) + return p.Precpred(p.GetParserRuleContext(), 17) case 4: - return p.Precpred(p.GetParserRuleContext(), 15) + return p.Precpred(p.GetParserRuleContext(), 16) case 5: return p.Precpred(p.GetParserRuleContext(), 10) @@ -3312,7 +3488,7 @@ func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) boo return p.Precpred(p.GetParserRuleContext(), 2) case 14: - return p.Precpred(p.GetParserRuleContext(), 22) + return p.Precpred(p.GetParserRuleContext(), 23) default: panic("No predicate with index: " + fmt.Sprint(predIndex)) diff --git a/internal/parser/planparserv2/generated/plan_visitor.go b/internal/parser/planparserv2/generated/plan_visitor.go index acaa0a833b..a043068901 100644 --- a/internal/parser/planparserv2/generated/plan_visitor.go +++ b/internal/parser/planparserv2/generated/plan_visitor.go @@ -46,6 +46,9 @@ type PlanVisitor interface { // Visit a parse tree produced by PlanParser#Shift. VisitShift(ctx *ShiftContext) interface{} + // Visit a parse tree produced by PlanParser#Call. + VisitCall(ctx *CallContext) interface{} + // Visit a parse tree produced by PlanParser#ReverseRange. VisitReverseRange(ctx *ReverseRangeContext) interface{} diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 92af0da0c4..cf9280a720 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -594,6 +594,28 @@ func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode) return v.getColumnInfoFromJSONIdentifier(child.GetText()) } +// VisitCall parses the expr to call plan. +func (v *ParserVisitor) VisitCall(ctx *parser.CallContext) interface{} { + functionName := strings.ToLower(ctx.Identifier().GetText()) + numParams := len(ctx.AllExpr()) + funcParameters := make([]*planpb.Expr, 0, numParams) + for _, param := range ctx.AllExpr() { + paramExpr := getExpr(param.Accept(v)) + funcParameters = append(funcParameters, paramExpr.expr) + } + return &ExprWithType{ + expr: &planpb.Expr{ + Expr: &planpb.Expr_CallExpr{ + CallExpr: &planpb.CallExpr{ + FunctionName: functionName, + FunctionParameters: funcParameters, + }, + }, + }, + dataType: schemapb.DataType_Bool, + } +} + // VisitRange translates expr to range plan. func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier()) diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index 5c74535cec..8d884c5652 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -58,13 +58,11 @@ func newTestSchemaHelper(t *testing.T) *typeutil.SchemaHelper { } func assertValidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) { - _, err := ParseExpr(helper, exprStr) + expr, err := ParseExpr(helper, exprStr) assert.NoError(t, err, exprStr) - - // expr, err := ParseExpr(helper, exprStr) - // assert.NoError(t, err, exprStr) // fmt.Printf("expr: %s\n", exprStr) - // ShowExpr(expr) + assert.NotNil(t, expr, exprStr) + ShowExpr(expr) } func assertInvalidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) { @@ -106,6 +104,43 @@ func TestExpr_Term(t *testing.T) { } } +func TestExpr_Call(t *testing.T) { + schema := newTestSchema() + helper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + testcases := []struct { + CallExpr string + FunctionName string + ParameterNum int + }{ + {`hello123()`, "hello123", 0}, + {`lt(Int32Field)`, "lt", 1}, + // test parens + {`lt((((Int32Field))))`, "lt", 1}, + {`empty(VarCharField,)`, "empty", 1}, + {`f2(Int64Field)`, "f2", 1}, + {`f2(Int64Field, 4)`, "f2", 2}, + {`f3(JSON_FIELD["A"], Int32Field)`, "f3", 2}, + {`f5(3+3, Int32Field)`, "f5", 2}, + } + for _, testcase := range testcases { + expr, err := ParseExpr(helper, testcase.CallExpr) + assert.NoError(t, err, testcase) + assert.Equal(t, testcase.FunctionName, expr.GetCallExpr().FunctionName, testcase) + assert.Equal(t, testcase.ParameterNum, len(expr.GetCallExpr().FunctionParameters), testcase) + ShowExpr(expr) + } + + expr, err := ParseExpr(helper, "xxx(1+1, !true, f(10+10))") + assert.NoError(t, err) + assert.Equal(t, "xxx", expr.GetCallExpr().FunctionName) + assert.Equal(t, 3, len(expr.GetCallExpr().FunctionParameters)) + assert.Equal(t, int64(2), expr.GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val()) + assert.Equal(t, false, expr.GetCallExpr().GetFunctionParameters()[1].GetValueExpr().GetValue().GetBoolVal()) + assert.Equal(t, int64(20), expr.GetCallExpr().GetFunctionParameters()[2].GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val()) +} + func TestExpr_Compare(t *testing.T) { schema := newTestSchema() helper, err := typeutil.CreateSchemaHelper(schema) @@ -247,7 +282,7 @@ func TestExpr_BinaryArith(t *testing.T) { exprStrs := []string{ `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, - `Int64Field + 1.1 == 2.1`, + `FloatField + 1.1 == 2.1`, `A % 10 != 2`, `Int8Field + 1 < 2`, `Int16Field - 3 <= 4`, @@ -265,6 +300,13 @@ func TestExpr_BinaryArith(t *testing.T) { assertValidExpr(t, helper, exprStr) } + invalidExprs := []string{ + `Int64Field + 1.1 == 2.1`, + } + for _, exprStr := range invalidExprs { + assertInvalidExpr(t, helper, exprStr) + } + // TODO: enable these after execution backend is ready. unsupported := []string{ `ArrayField + 15 == 16`, @@ -286,6 +328,7 @@ func TestExpr_Value(t *testing.T) { `true`, `false`, `"str"`, + `3 > 2`, } for _, exprStr := range exprStrs { expr := handleExpr(helper, exprStr) @@ -935,6 +978,8 @@ func Test_JSONContains(t *testing.T) { `json_contains(JSONField["x"], 5)`, `not json_contains(JSONField["x"], 5)`, `JSON_CONTAINS(JSONField["x"], 5)`, + `json_Contains(JSONField, 5)`, + `JSON_contains(JSONField, 5)`, `json_contains(A, [1,2,3])`, `array_contains(A, [1,2,3])`, `array_contains(ArrayField, [1,2,3])`, @@ -970,8 +1015,6 @@ func Test_InvalidJSONContains(t *testing.T) { `json_contains(A, StringField > 5)`, `json_contains(A)`, `json_contains(A, 5, C)`, - `json_Contains(JSONField, 5)`, - `JSON_contains(JSONField, 5)`, } for _, expr = range exprs { _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ diff --git a/internal/parser/planparserv2/show_visitor.go b/internal/parser/planparserv2/show_visitor.go index b9b263b6e0..1a06d93d5e 100644 --- a/internal/parser/planparserv2/show_visitor.go +++ b/internal/parser/planparserv2/show_visitor.go @@ -46,6 +46,8 @@ func (v *ShowExprVisitor) VisitExpr(expr *planpb.Expr) interface{} { js["expr"] = v.VisitUnaryExpr(realExpr.UnaryExpr) case *planpb.Expr_BinaryExpr: js["expr"] = v.VisitBinaryExpr(realExpr.BinaryExpr) + case *planpb.Expr_CallExpr: + js["expr"] = v.VisitCallExpr(realExpr.CallExpr) case *planpb.Expr_CompareExpr: js["expr"] = v.VisitCompareExpr(realExpr.CompareExpr) case *planpb.Expr_UnaryRangeExpr: @@ -93,6 +95,18 @@ func (v *ShowExprVisitor) VisitBinaryExpr(expr *planpb.BinaryExpr) interface{} { return js } +func (v *ShowExprVisitor) VisitCallExpr(expr *planpb.CallExpr) interface{} { + js := make(map[string]interface{}) + js["expr_type"] = "call" + js["func_name"] = expr.FunctionName + params := make([]interface{}, 0, len(expr.FunctionParameters)) + for _, p := range expr.FunctionParameters { + params = append(params, v.VisitExpr(p)) + } + js["func_parameters"] = params + return js +} + func (v *ShowExprVisitor) VisitCompareExpr(expr *planpb.CompareExpr) interface{} { js := make(map[string]interface{}) js["expr_type"] = "compare" @@ -164,6 +178,6 @@ func NewShowExprVisitor() LogicalExprVisitor { func ShowExpr(expr *planpb.Expr) { v := NewShowExprVisitor() js := v.VisitExpr(expr) - b, _ := json.MarshalIndent(js, "", " ") + b, _ := json.Marshal(js) log.Info("[ShowExpr]", zap.String("expr", string(b))) } diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index db87dcfa2b..6708e64808 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -223,14 +223,14 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb. return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value) } -func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operand *planpb.GenericValue, value *planpb.GenericValue) *planpb.Expr { +func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operand *planpb.GenericValue, value *planpb.GenericValue) (*planpb.Expr, error) { dataType := columnInfo.GetDataType() if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 { dataType = columnInfo.GetElementType() } castedValue, err := castValue(dataType, operand) if err != nil { - return nil + return nil, err } return &planpb.Expr{ Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ @@ -242,7 +242,7 @@ func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, column Value: value, }, }, - } + }, nil } func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, value *planpb.GenericValue) (*planpb.Expr, error) { @@ -282,7 +282,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, // a * 2 == 3 // a / 2 == 3 // a % 2 == 3 - return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue()), nil + return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue()) } else if rightExpr != nil && leftValue != nil { // 2 + a == 3 // 2 - a == 3 @@ -292,7 +292,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, switch arithExpr.GetOp() { case planpb.ArithOpType_Add, planpb.ArithOpType_Mul: - return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue()), nil + return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue()) default: return nil, fmt.Errorf("module field is not yet supported") } diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index 16ed9aee2b..0ee1d6c03a 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -105,6 +105,11 @@ message BinaryRangeExpr { GenericValue upper_value = 5; } +message CallExpr { + string function_name = 1; + repeated Expr function_parameters = 2; +} + message CompareExpr { ColumnInfo left_column_info = 1; ColumnInfo right_column_info = 2; @@ -191,6 +196,7 @@ message Expr { ExistsExpr exists_expr = 11; AlwaysTrueExpr always_true_expr = 12; JSONContainsExpr json_contains_expr = 13; + CallExpr call_expr = 14; }; } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 19ac5e7a96..051630a527 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -23,6 +23,7 @@ package querynodev2 #include "segcore/segment_c.h" #include "segcore/segcore_init_c.h" #include "common/init_c.h" +#include "exec/expression/function/init_c.h" */ import "C" @@ -252,6 +253,7 @@ func (node *QueryNode) InitSegcore() error { } initcore.InitTraceConfig(paramtable.Get()) + C.InitExecExpressionFunctionFactory() return nil } diff --git a/tests/go_client/common/utils.go b/tests/go_client/common/utils.go index b87e4ff837..f66d9cb8a9 100644 --- a/tests/go_client/common/utils.go +++ b/tests/go_client/common/utils.go @@ -151,8 +151,8 @@ var InvalidExpressions = []InvalidExprStruct{ {Expr: fmt.Sprintf("json_contains (%s['list'], [2])", DefaultJSONFieldName), ErrNil: true, ErrMsg: ""}, {Expr: fmt.Sprintf("json_contains_all (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_all operation element must be an array"}, {Expr: fmt.Sprintf("JSON_CONTAINS_ANY (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_any operation element must be an array"}, - {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression: json_contains_aby"}, - {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression: json_contains_aby"}, + {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "function json_contains_aby(json, int64_t) not found."}, + {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "function json_contains_aby(json, int64_t) not found."}, {Expr: fmt.Sprintf("%s[-1] > %d", DefaultInt8ArrayField, TestCapacity), ErrNil: false, ErrMsg: "cannot parse expression"}, // array[-1] > {Expr: fmt.Sprintf("%s[-1] > 1", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression"}, // json[-1] > } diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index f15184b549..ac8f47b5e2 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -5415,3 +5415,59 @@ class TestQueryTextMatchNegative(TestcaseBase): check_task=CheckTasks.err_res, check_items=error, ) + + +class TestQueryFunction(TestcaseBase): + @pytest.mark.tags(CaseLabel.L1) + def test_query_function_calls(self): + """ + target: test query data + method: create collection and insert data + query with mix call expr in string field and int field + expected: query successfully + """ + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, + primary_field=ct.default_string_field_name)[0:2] + res = vectors[0].iloc[:, 1:3].to_dict('records') + output_fields = [default_float_field_name, default_string_field_name] + for mixed_call_expr in [ + "not empty(varchar) && int64 >= 0", + # function call is case-insensitive + "not EmPty(varchar) && int64 >= 0", + "not EMPTY(varchar) && int64 >= 0", + "starts_with(varchar, varchar) && int64 >= 0", + ]: + collection_w.query( + mixed_call_expr, + output_fields=output_fields, + check_task=CheckTasks.check_query_results, + check_items={exp_res: res}, + ) + + @pytest.mark.tags(CaseLabel.L1) + def test_query_invalid(self): + """ + target: test query with invalid call expression + method: query with invalid call expr + expected: raise exception + """ + collection_w, entities = self.init_collection_general( + prefix, insert_data=True, nb=10 + )[0:2] + test_cases = [ + ( + "A_FUNCTION_THAT_DOES_NOT_EXIST()", + "function A_FUNCTION_THAT_DOES_NOT_EXIST() not found", + ), + # empty + ("empty()", "function empty() not found"), + (f"empty({default_int_field_name})", "function empty(int64_t) not found"), + # starts_with + (f"starts_with({default_int_field_name})", "function starts_with(int64_t) not found"), + (f"starts_with({default_int_field_name}, {default_int_field_name})", "function starts_with(int64_t, int64_t) not found"), + ] + for call_expr, err_msg in test_cases: + error = {ct.err_code: 65535, ct.err_msg: err_msg} + collection_w.query( + call_expr, check_task=CheckTasks.err_res, check_items=error + )