feat: Implement custom function module in milvus expr (#36560)

OSPP 2024 project:
https://summer-ospp.ac.cn/org/prodetail/247410235?list=org&navpage=org

Solutions:

- parser (planparserv2)
    - add CallExpr in planparserv2/Plan.g4
    - update parser_visitor and show_visitor
- grpc protobuf
    - add CallExpr in plan.proto
- execution (`core/src/exec`)
- add `CallExpr` `ValueExpr` and `ColumnExpr` (both logical and
physical) for function call and function parameters
- function factory (`core/src/exec/expression/function`)
    - create a global hashmap when starting milvus (see server.go)
- the global hashmap stores function signatures and their function
pointers, the CallExpr in execution engine can get the function pointer
by function signature.
- custom functions
    - empty(string)
    - starts_with(string, string)
- add cpp/go unittests and E2E tests

closes: #36559

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
This commit is contained in:
Yinzuo Jiang 2024-10-25 15:25:30 +08:00 committed by GitHub
parent b45cf2d49f
commit 3628593d20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 3003 additions and 791 deletions

View File

@ -1,20 +1,15 @@
# Visitor Pattern # Visitor Pattern
Visitor Pattern is used in segcore for parse and execute Execution Plan. Visitor Pattern is used in segcore for parse and execute Execution Plan.
1. Inside `${core}/src/query/PlanNode.h`, contains physical plan for vector search: 1. Inside `${internal/core}/src/query/PlanNode.h`, contains physical plan for vector search:
1. `FloatVectorANNS` FloatVector search execution node 1. `FloatVectorANNS` FloatVector search execution node
2. `BinaryVectorANNS` BinaryVector search execution node 2. `BinaryVectorANNS` BinaryVector search execution node
2. `${core}/src/query/Expr.h` contains physical plan for scalar expression: 2. `${internal/core}/src/query/Expr.h` contains physical plan for scalar expression:
1. `TermExpr` support operation like `col in [1, 2, 3]` 1. `TermExpr` support operation like `col in [1, 2, 3]`
2. `RangeExpr` support constant compare with data column like `a >= 5` `1 < b < 2` 2. `RangeExpr` support constant compare with data column like `a >= 5` `1 < b < 2`
3. `CompareExpr` support compare with different columns, like `a < b` 3. `CompareExpr` support compare with different columns, like `a < b`
4. `LogicalBinaryExpr` support and/or 4. `LogicalBinaryExpr` support and/or
5. `LogicalUnaryExpr` support not 5. `LogicalUnaryExpr` support not
Currently, under `${core/query/visitors}` directory, there are the following visitors: Currently, under `${internal/core/src/query}` directory, there are the following visitors:
1. `ShowPlanNodeVisitor` prints PlanNode in json 1. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now
2. `ShowExprVisitor` Expr -> json
3. `Verify...Visitor` validates ...
4. `ExtractInfo...Visitor` extracts info from..., including involved_fields and else
5. `ExecExprVisitor` generates bitmask according to expression
6. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now

View File

@ -292,6 +292,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/segcore/
FILES_MATCHING PATTERN "*_c.h" FILES_MATCHING PATTERN "*_c.h"
) )
# Install exec/expression/function
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/exec/expression/function/
DESTINATION include/exec/expression/function
FILES_MATCHING PATTERN "*_c.h"
)
# Install indexbuilder # Install indexbuilder
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/
DESTINATION include/indexbuilder DESTINATION include/indexbuilder

View File

@ -59,6 +59,7 @@ using float16 = knowhere::fp16;
using bfloat16 = knowhere::bf16; using bfloat16 = knowhere::bf16;
using bin1 = knowhere::bin1; using bin1 = knowhere::bin1;
// See also: https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto
enum class DataType { enum class DataType {
NONE = 0, NONE = 0,
BOOL = 1, BOOL = 1,
@ -682,4 +683,4 @@ struct fmt::formatter<milvus::OpType> : formatter<string_view> {
} }
return formatter<string_view>::format(name, ctx); return formatter<string_view>::format(name, ctx);
} }
}; };

View File

@ -17,11 +17,13 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <string>
#include "EasyAssert.h" #include "EasyAssert.h"
#include "Types.h" #include "Types.h"
#include "bitset/bitset.h"
#include "common/FieldData.h" #include "common/FieldData.h"
#include "common/FieldDataInterface.h"
#include "common/Types.h"
namespace milvus { namespace milvus {
@ -29,7 +31,6 @@ namespace milvus {
* @brief base class for different type vector * @brief base class for different type vector
* @todo implement full null value support * @todo implement full null value support
*/ */
class BaseVector { class BaseVector {
public: public:
BaseVector(DataType data_type, BaseVector(DataType data_type,
@ -58,18 +59,39 @@ class BaseVector {
using VectorPtr = std::shared_ptr<BaseVector>; using VectorPtr = std::shared_ptr<BaseVector>;
/**
* SimpleVector abstracts over various Columnar Storage Formats,
* it is used in custom functions.
*/
class SimpleVector : public BaseVector {
public:
SimpleVector(DataType data_type,
size_t length,
std::optional<size_t> null_count = std::nullopt)
: BaseVector(data_type, length, null_count) {
}
virtual void*
RawValueAt(size_t index, size_t size_of_element) = 0;
virtual bool
ValidAt(size_t index) = 0;
};
/** /**
* @brief Single vector for scalar types * @brief Single vector for scalar types
* @todo using memory pool && buffer replace FieldData * @todo using memory pool && buffer replace FieldData
*/ */
class ColumnVector final : public BaseVector { class ColumnVector final : public SimpleVector {
public: public:
ColumnVector(DataType data_type, ColumnVector(DataType data_type,
size_t length, size_t length,
std::optional<size_t> null_count = std::nullopt) std::optional<size_t> null_count = std::nullopt)
: BaseVector(data_type, length, null_count) { : SimpleVector(data_type, length, null_count),
is_bitmap_(false),
valid_values_(length,
!null_count.has_value() || null_count.value() == 0) {
values_ = InitScalarFieldData(data_type, false, length); values_ = InitScalarFieldData(data_type, false, length);
valid_values_ = InitScalarFieldData(data_type, false, length);
} }
// ColumnVector(FixedVector<bool>&& data) // ColumnVector(FixedVector<bool>&& data)
@ -78,20 +100,14 @@ class ColumnVector final : public BaseVector {
// std::make_shared<FieldData<bool>>(DataType::BOOL, std::move(data)); // std::make_shared<FieldData<bool>>(DataType::BOOL, std::move(data));
// } // }
// // the size is the number of bits
// ColumnVector(TargetBitmap&& bitmap)
// : BaseVector(DataType::INT8, bitmap.size()) {
// values_ = std::make_shared<FieldDataImpl<uint8_t, false>>(
// bitmap.size(), DataType::INT8, false, std::move(bitmap).into());
// }
// the size is the number of bits // the size is the number of bits
// TODO: separate the usage of bitmap from scalar field data
ColumnVector(TargetBitmap&& bitmap, TargetBitmap&& valid_bitmap) ColumnVector(TargetBitmap&& bitmap, TargetBitmap&& valid_bitmap)
: BaseVector(DataType::INT8, bitmap.size()) { : SimpleVector(DataType::INT8, bitmap.size()),
is_bitmap_(true),
valid_values_(std::move(valid_bitmap)) {
values_ = std::make_shared<FieldBitsetImpl<uint8_t>>(DataType::INT8, values_ = std::make_shared<FieldBitsetImpl<uint8_t>>(DataType::INT8,
std::move(bitmap)); std::move(bitmap));
valid_values_ = std::make_shared<FieldBitsetImpl<uint8_t>>(
DataType::INT8, std::move(valid_bitmap));
} }
virtual ~ColumnVector() override { virtual ~ColumnVector() override {
@ -100,28 +116,81 @@ class ColumnVector final : public BaseVector {
} }
void* void*
GetRawData() { RawValueAt(size_t index, size_t size_of_element) override {
return reinterpret_cast<char*>(GetRawData()) + index * size_of_element;
}
bool
ValidAt(size_t index) override {
return valid_values_[index];
}
void*
GetRawData() const {
return values_->Data(); return values_->Data();
} }
void* void*
GetValidRawData() { GetValidRawData() {
return valid_values_->Data(); return valid_values_.data();
} }
template <typename As> template <typename As>
const As* As*
RawAsValues() const { RawAsValues() const {
return reinterpret_cast<const As*>(values_->Data()); return reinterpret_cast<As*>(values_->Data());
}
bool
IsBitmap() const {
return is_bitmap_;
} }
private: private:
bool is_bitmap_; // TODO: remove the field after implementing BitmapVector
FieldDataPtr values_; FieldDataPtr values_;
FieldDataPtr valid_values_; TargetBitmap valid_values_; // false means the value is null
}; };
using ColumnVectorPtr = std::shared_ptr<ColumnVector>; using ColumnVectorPtr = std::shared_ptr<ColumnVector>;
template <typename T>
class ConstantVector : public SimpleVector {
public:
ConstantVector(DataType data_type,
size_t length,
const T& val,
std::optional<size_t> null_count = std::nullopt)
: SimpleVector(data_type, length),
val_(val),
is_null_(null_count.has_value() && null_count.value() > 0) {
}
void*
RawValueAt(size_t _index, size_t _size_of_element) override {
return &val_;
}
bool
ValidAt(size_t _index) override {
return !is_null_;
}
const T&
GetValue() const {
return val_;
}
bool
IsNull() const {
return is_null_;
}
private:
T val_;
bool is_null_;
};
/** /**
* @brief Multi vectors for scalar types * @brief Multi vectors for scalar types
* mainly using it to pass internal result in segcore scalar engine system * mainly using it to pass internal result in segcore scalar engine system
@ -149,8 +218,7 @@ class RowVector : public BaseVector {
} }
RowVector(std::vector<VectorPtr>&& children) RowVector(std::vector<VectorPtr>&& children)
: BaseVector(DataType::ROW, 0) { : BaseVector(DataType::ROW, 0), children_values_(std::move(children)) {
children_values_ = std::move(children);
for (auto& child : children_values_) { for (auto& child : children_values_) {
if (child->size() > length_) { if (child->size() > length_) {
length_ = child->size(); length_ = child->size();
@ -159,12 +227,12 @@ class RowVector : public BaseVector {
} }
const std::vector<VectorPtr>& const std::vector<VectorPtr>&
childrens() { childrens() const {
return children_values_; return children_values_;
} }
VectorPtr VectorPtr
child(int index) { child(int index) const {
assert(index < children_values_.size()); assert(index < children_values_.size());
return children_values_[index]; return children_values_[index];
} }
@ -174,5 +242,4 @@ class RowVector : public BaseVector {
}; };
using RowVectorPtr = std::shared_ptr<RowVector>; using RowVectorPtr = std::shared_ptr<RowVector>;
} // namespace milvus } // namespace milvus

View File

@ -105,4 +105,4 @@ SetTrace(CTraceConfig* config) {
config->oltpSecure, config->oltpSecure,
config->nodeID}; config->nodeID};
milvus::tracer::initTelemetry(traceConfig); milvus::tracer::initTelemetry(traceConfig);
} }

View File

@ -235,4 +235,4 @@ Task::Next(ContinueFuture* future) {
} }
} // namespace exec } // namespace exec
} // namespace milvus } // namespace milvus

View File

@ -0,0 +1,46 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/FieldDataInterface.h"
#include "common/Vector.h"
#include "exec/expression/CallExpr.h"
#include "exec/expression/EvalCtx.h"
#include "exec/expression/function/FunctionFactory.h"
#include <utility>
#include <vector>
namespace milvus {
namespace exec {
void
PhyCallExpr::Eval(EvalCtx& context, VectorPtr& result) {
AssertInfo(inputs_.size() == expr_->inputs().size(),
"logical call expr needs {} inputs, but {} inputs are provided",
expr_->inputs().size(),
inputs_.size());
std::vector<VectorPtr> args;
for (auto& input : this->inputs_) {
VectorPtr arg_result;
input->Eval(context, arg_result);
args.push_back(std::move(arg_result));
}
RowVector row_vector(std::move(args));
this->expr_->function_ptr()(row_vector, result);
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,83 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/EasyAssert.h"
#include "common/FieldDataInterface.h"
#include "common/Utils.h"
#include "common/Vector.h"
#include "exec/expression/EvalCtx.h"
#include "exec/expression/Expr.h"
#include "exec/expression/function/FunctionFactory.h"
#include "expr/ITypeExpr.h"
#include "fmt/core.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
class PhyCallExpr : public Expr {
public:
PhyCallExpr(const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::CallExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
int64_t active_count,
int64_t batch_size)
: Expr(DataType::BOOL, std::move(input), name),
expr_(expr),
active_count_(active_count),
segment_(segment),
batch_size_(batch_size) {
size_per_chunk_ = segment_->size_per_chunk();
num_chunk_ = upper_div(active_count_, size_per_chunk_);
AssertInfo(
batch_size_ > 0,
fmt::format("expr batch size should greater than zero, but now: {}",
batch_size_));
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
void
MoveCursor() override {
for (auto input : inputs_) {
input->MoveCursor();
}
}
private:
std::shared_ptr<const milvus::expr::CallExpr> expr_;
int64_t active_count_{0};
int64_t num_chunk_{0};
int64_t current_chunk_id_{0};
int64_t current_chunk_pos_{0};
int64_t size_per_chunk_{0};
const segcore::SegmentInternalInterface* segment_;
int64_t batch_size_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,147 @@
// Licensed to the LF AI & Data foundation under on0
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ColumnExpr.h"
namespace milvus {
namespace exec {
int64_t
PhyColumnExpr::GetNextBatchSize() {
auto current_rows = GetCurrentRows();
return current_rows + batch_size_ >= segment_chunk_reader_.active_count_
? segment_chunk_reader_.active_count_ - current_rows
: batch_size_;
}
void
PhyColumnExpr::Eval(EvalCtx& context, VectorPtr& result) {
switch (this->expr_->type()) {
case DataType::BOOL:
result = DoEval<bool>();
break;
case DataType::INT8:
result = DoEval<int8_t>();
break;
case DataType::INT16:
result = DoEval<int16_t>();
break;
case DataType::INT32:
result = DoEval<int32_t>();
break;
case DataType::INT64:
result = DoEval<int64_t>();
break;
case DataType::FLOAT:
result = DoEval<float>();
break;
case DataType::DOUBLE:
result = DoEval<double>();
break;
case DataType::VARCHAR: {
result = DoEval<std::string>();
break;
}
default:
PanicInfo(DataTypeInvalid,
"unsupported data type: {}",
this->expr_->type());
}
}
template <typename T>
VectorPtr
PhyColumnExpr::DoEval() {
// similar to PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op)
if (segment_chunk_reader_.segment_->is_chunked()) {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec = std::make_shared<ColumnVector>(
expr_->GetColumn().data_type_, real_batch_size);
T* res_value = res_vec->RawAsValues<T>();
TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size);
valid_res.set();
auto chunk_data = segment_chunk_reader_.GetChunkDataAccessor(
expr_->GetColumn().data_type_,
expr_->GetColumn().field_id_,
is_indexed_,
current_chunk_id_,
current_chunk_pos_);
for (int i = 0; i < real_batch_size; ++i) {
if (!chunk_data().has_value()) {
valid_res[i] = false;
continue;
}
res_value[i] = boost::get<T>(chunk_data().value());
}
return res_vec;
} else {
auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) {
return nullptr;
}
auto res_vec = std::make_shared<ColumnVector>(
expr_->GetColumn().data_type_, real_batch_size);
T* res_value = res_vec->RawAsValues<T>();
TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size);
valid_res.set();
auto data_barrier = segment_chunk_reader_.segment_->num_chunk_data(
expr_->GetColumn().field_id_);
int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
++chunk_id) {
auto chunk_size =
chunk_id == num_chunk_ - 1
? segment_chunk_reader_.active_count_ -
chunk_id * segment_chunk_reader_.SizePerChunk()
: segment_chunk_reader_.SizePerChunk();
auto chunk_data = segment_chunk_reader_.GetChunkDataAccessor(
expr_->GetColumn().data_type_,
expr_->GetColumn().field_id_,
chunk_id,
data_barrier);
for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
i < chunk_size;
++i) {
if (!chunk_data(i).has_value()) {
valid_res[processed_rows] = false;
} else {
res_value[processed_rows] =
boost::get<T>(chunk_data(i).value());
}
processed_rows++;
if (processed_rows >= batch_size_) {
current_chunk_id_ = chunk_id;
current_chunk_pos_ = i + 1;
return res_vec;
}
}
}
return res_vec;
}
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,125 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include <boost/variant.hpp>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
#include "segcore/SegmentChunkReader.h"
namespace milvus {
namespace exec {
class PhyColumnExpr : public Expr {
public:
PhyColumnExpr(const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::ColumnExpr>& expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
int64_t active_count,
int64_t batch_size)
: Expr(expr->type(), std::move(input), name),
segment_chunk_reader_(segment, active_count),
batch_size_(batch_size),
expr_(expr) {
is_indexed_ = segment->HasIndex(expr_->GetColumn().field_id_);
if (segment->is_chunked()) {
num_chunk_ =
is_indexed_
? segment->num_chunk_index(expr_->GetColumn().field_id_)
: segment->type() == SegmentType::Growing
? upper_div(segment_chunk_reader_.active_count_,
segment_chunk_reader_.SizePerChunk())
: segment->num_chunk_data(expr_->GetColumn().field_id_);
} else {
num_chunk_ =
is_indexed_
? segment->num_chunk_index(expr_->GetColumn().field_id_)
: upper_div(segment_chunk_reader_.active_count_,
segment_chunk_reader_.SizePerChunk());
}
AssertInfo(
batch_size_ > 0,
fmt::format("expr batch size should greater than zero, but now: {}",
batch_size_));
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
void
MoveCursor() override {
if (segment_chunk_reader_.segment_->is_chunked()) {
segment_chunk_reader_.MoveCursorForMultipleChunk(
current_chunk_id_,
current_chunk_pos_,
expr_->GetColumn().field_id_,
num_chunk_,
batch_size_);
} else {
segment_chunk_reader_.MoveCursorForSingleChunk(
current_chunk_id_, current_chunk_pos_, num_chunk_, batch_size_);
}
}
private:
int64_t
GetCurrentRows() const {
if (segment_chunk_reader_.segment_->is_chunked()) {
auto current_rows =
is_indexed_ && segment_chunk_reader_.segment_->type() ==
SegmentType::Sealed
? current_chunk_pos_
: segment_chunk_reader_.segment_->num_rows_until_chunk(
expr_->GetColumn().field_id_, current_chunk_id_) +
current_chunk_pos_;
return current_rows;
} else {
return segment_chunk_reader_.segment_->type() ==
SegmentType::Growing
? current_chunk_id_ *
segment_chunk_reader_.SizePerChunk() +
current_chunk_pos_
: current_chunk_pos_;
}
}
int64_t
GetNextBatchSize();
template <typename T>
VectorPtr
DoEval();
private:
bool is_indexed_;
int64_t num_chunk_{0};
int64_t current_chunk_id_{0};
int64_t current_chunk_pos_{0};
const segcore::SegmentChunkReader segment_chunk_reader_;
int64_t batch_size_;
std::shared_ptr<const milvus::expr::ColumnExpr> expr_;
};
} //namespace exec
} // namespace milvus

View File

@ -15,7 +15,6 @@
// limitations under the License. // limitations under the License.
#include "CompareExpr.h" #include "CompareExpr.h"
#include "common/type_c.h"
#include <optional> #include <optional>
#include "query/Relational.h" #include "query/Relational.h"
@ -32,212 +31,15 @@ int64_t
PhyCompareFilterExpr::GetNextBatchSize() { PhyCompareFilterExpr::GetNextBatchSize() {
auto current_rows = GetCurrentRows(); auto current_rows = GetCurrentRows();
return current_rows + batch_size_ >= active_count_ return current_rows + batch_size_ >= segment_chunk_reader_.active_count_
? active_count_ - current_rows ? segment_chunk_reader_.active_count_ - current_rows
: batch_size_; : batch_size_;
} }
template <typename T>
MultipleChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) {
if (index) {
auto& indexing = const_cast<index::ScalarIndex<T>&>(
segment_->chunk_scalar_index<T>(field_id, current_chunk_id));
auto current_chunk_size = segment_->type() == SegmentType::Growing
? size_per_chunk_
: active_count_;
if (indexing.HasRawData()) {
return [&, current_chunk_size]() -> const number {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
indexing = const_cast<index::ScalarIndex<T>&>(
segment_->chunk_scalar_index<T>(field_id,
current_chunk_id));
}
auto raw = indexing.Reverse_Lookup(current_chunk_pos);
current_chunk_pos++;
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
auto chunk_data =
segment_->chunk_data<T>(field_id, current_chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<T>(field_id, current_chunk_id).valid_data();
auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id);
return
[=, &current_chunk_id, &current_chunk_pos]() mutable -> const number {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data =
segment_->chunk_data<T>(field_id, current_chunk_id).data();
chunk_valid_data =
segment_->chunk_data<T>(field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return chunk_data[current_chunk_pos++];
};
}
template <>
MultipleChunkDataAccessor
PhyCompareFilterExpr::GetChunkData<std::string>(FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) {
if (index) {
auto& indexing = const_cast<index::ScalarIndex<std::string>&>(
segment_->chunk_scalar_index<std::string>(field_id,
current_chunk_id));
auto current_chunk_size = segment_->type() == SegmentType::Growing
? size_per_chunk_
: active_count_;
if (indexing.HasRawData()) {
return [&, current_chunk_size]() mutable -> const number {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
indexing = const_cast<index::ScalarIndex<std::string>&>(
segment_->chunk_scalar_index<std::string>(
field_id, current_chunk_id));
}
auto raw = indexing.Reverse_Lookup(current_chunk_pos);
current_chunk_pos++;
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
if (segment_->type() == SegmentType::Growing &&
!storage::MmapManager::GetInstance()
.GetMmapConfig()
.growing_enable_mmap) {
auto chunk_data =
segment_->chunk_data<std::string>(field_id, current_chunk_id)
.data();
auto chunk_valid_data =
segment_->chunk_data<std::string>(field_id, current_chunk_id)
.valid_data();
auto current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
return [=,
&current_chunk_id,
&current_chunk_pos]() mutable -> const number {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data =
segment_
->chunk_data<std::string>(field_id, current_chunk_id)
.data();
chunk_valid_data =
segment_
->chunk_data<std::string>(field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return chunk_data[current_chunk_pos++];
};
} else {
auto chunk_data =
segment_->chunk_view<std::string_view>(field_id, current_chunk_id)
.first.data();
auto chunk_valid_data =
segment_->chunk_data<std::string_view>(field_id, current_chunk_id)
.valid_data();
auto current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
return [=,
&current_chunk_id,
&current_chunk_pos]() mutable -> const number {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data = segment_
->chunk_view<std::string_view>(
field_id, current_chunk_id)
.first.data();
chunk_valid_data = segment_
->chunk_data<std::string_view>(
field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return std::string(chunk_data[current_chunk_pos++]);
};
}
}
MultipleChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(DataType data_type,
FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) {
switch (data_type) {
case DataType::BOOL:
return GetChunkData<bool>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT8:
return GetChunkData<int8_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT16:
return GetChunkData<int16_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT32:
return GetChunkData<int32_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT64:
return GetChunkData<int64_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::FLOAT:
return GetChunkData<float>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::DOUBLE:
return GetChunkData<double>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::VARCHAR: {
return GetChunkData<std::string>(
field_id, index, current_chunk_id, current_chunk_pos);
}
default:
PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
}
}
template <typename OpType> template <typename OpType>
VectorPtr VectorPtr
PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
if (segment_->is_chunked()) { if (segment_chunk_reader_.segment_->is_chunked()) {
auto real_batch_size = GetNextBatchSize(); auto real_batch_size = GetNextBatchSize();
if (real_batch_size == 0) { if (real_batch_size == 0) {
return nullptr; return nullptr;
@ -249,16 +51,18 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size);
valid_res.set(); valid_res.set();
auto left = GetChunkData(expr_->left_data_type_, auto left =
expr_->left_field_id_, segment_chunk_reader_.GetChunkDataAccessor(expr_->left_data_type_,
is_left_indexed_, expr_->left_field_id_,
left_current_chunk_id_, is_left_indexed_,
left_current_chunk_pos_); left_current_chunk_id_,
auto right = GetChunkData(expr_->right_data_type_, left_current_chunk_pos_);
expr_->right_field_id_, auto right = segment_chunk_reader_.GetChunkDataAccessor(
is_right_indexed_, expr_->right_data_type_,
right_current_chunk_id_, expr_->right_field_id_,
right_current_chunk_pos_); is_right_indexed_,
right_current_chunk_id_,
right_current_chunk_pos_);
for (int i = 0; i < real_batch_size; ++i) { for (int i = 0; i < real_batch_size; ++i) {
if (!left().has_value() || !right().has_value()) { if (!left().has_value() || !right().has_value()) {
res[i] = false; res[i] = false;
@ -283,25 +87,30 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size);
valid_res.set(); valid_res.set();
auto left_data_barrier = auto left_data_barrier = segment_chunk_reader_.segment_->num_chunk_data(
segment_->num_chunk_data(expr_->left_field_id_); expr_->left_field_id_);
auto right_data_barrier = auto right_data_barrier =
segment_->num_chunk_data(expr_->right_field_id_); segment_chunk_reader_.segment_->num_chunk_data(
expr_->right_field_id_);
int64_t processed_rows = 0; int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
++chunk_id) { ++chunk_id) {
auto chunk_size = chunk_id == num_chunk_ - 1 auto chunk_size =
? active_count_ - chunk_id * size_per_chunk_ chunk_id == num_chunk_ - 1
: size_per_chunk_; ? segment_chunk_reader_.active_count_ -
auto left = GetChunkData(expr_->left_data_type_, chunk_id * segment_chunk_reader_.SizePerChunk()
expr_->left_field_id_, : segment_chunk_reader_.SizePerChunk();
chunk_id, auto left = segment_chunk_reader_.GetChunkDataAccessor(
left_data_barrier); expr_->left_data_type_,
auto right = GetChunkData(expr_->right_data_type_, expr_->left_field_id_,
expr_->right_field_id_, chunk_id,
chunk_id, left_data_barrier);
right_data_barrier); auto right = segment_chunk_reader_.GetChunkDataAccessor(
expr_->right_data_type_,
expr_->right_field_id_,
chunk_id,
right_data_barrier);
for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
i < chunk_size; i < chunk_size;
@ -328,108 +137,6 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) {
} }
} }
template <typename T>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(FieldId field_id,
int chunk_id,
int data_barrier) {
if (chunk_id >= data_barrier) {
auto& indexing = segment_->chunk_scalar_index<T>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const number {
auto raw = indexing.Reverse_Lookup(i);
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
auto chunk_data = segment_->chunk_data<T>(field_id, chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<T>(field_id, chunk_id).valid_data();
return [chunk_data, chunk_valid_data](int i) -> const number {
if (chunk_valid_data && !chunk_valid_data[i]) {
return std::nullopt;
}
return chunk_data[i];
};
}
template <>
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData<std::string>(FieldId field_id,
int chunk_id,
int data_barrier) {
if (chunk_id >= data_barrier) {
auto& indexing =
segment_->chunk_scalar_index<std::string>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const number {
auto raw = indexing.Reverse_Lookup(i);
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
if (segment_->type() == SegmentType::Growing &&
!storage::MmapManager::GetInstance()
.GetMmapConfig()
.growing_enable_mmap) {
auto chunk_data =
segment_->chunk_data<std::string>(field_id, chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<std::string>(field_id, chunk_id).valid_data();
return [chunk_data, chunk_valid_data](int i) -> const number {
if (chunk_valid_data && !chunk_valid_data[i]) {
return std::nullopt;
}
return chunk_data[i];
};
} else {
auto chunk_info =
segment_->chunk_view<std::string_view>(field_id, chunk_id);
auto chunk_data = chunk_info.first.data();
auto chunk_valid_data = chunk_info.second.data();
return [chunk_data, chunk_valid_data](int i) -> const number {
if (chunk_valid_data && !chunk_valid_data[i]) {
return std::nullopt;
}
return std::string(chunk_data[i]);
};
}
}
ChunkDataAccessor
PhyCompareFilterExpr::GetChunkData(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier) {
switch (data_type) {
case DataType::BOOL:
return GetChunkData<bool>(field_id, chunk_id, data_barrier);
case DataType::INT8:
return GetChunkData<int8_t>(field_id, chunk_id, data_barrier);
case DataType::INT16:
return GetChunkData<int16_t>(field_id, chunk_id, data_barrier);
case DataType::INT32:
return GetChunkData<int32_t>(field_id, chunk_id, data_barrier);
case DataType::INT64:
return GetChunkData<int64_t>(field_id, chunk_id, data_barrier);
case DataType::FLOAT:
return GetChunkData<float>(field_id, chunk_id, data_barrier);
case DataType::DOUBLE:
return GetChunkData<double>(field_id, chunk_id, data_barrier);
case DataType::VARCHAR: {
return GetChunkData<std::string>(field_id, chunk_id, data_barrier);
}
default:
PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
}
}
void void
PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
// For segment both fields has no index, can use SIMD to speed up. // For segment both fields has no index, can use SIMD to speed up.

View File

@ -18,7 +18,6 @@
#include <fmt/core.h> #include <fmt/core.h>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <optional>
#include "common/EasyAssert.h" #include "common/EasyAssert.h"
#include "common/Types.h" #include "common/Types.h"
@ -26,24 +25,11 @@
#include "common/type_c.h" #include "common/type_c.h"
#include "exec/expression/Expr.h" #include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h" #include "segcore/SegmentInterface.h"
#include "segcore/SegmentChunkReader.h"
namespace milvus { namespace milvus {
namespace exec { namespace exec {
using number_type = boost::variant<bool,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
std::string>;
using number = std::optional<number_type>;
using ChunkDataAccessor = std::function<const number(int)>;
using MultipleChunkDataAccessor = std::function<const number()>;
template <typename T, typename U, proto::plan::OpType op> template <typename T, typename U, proto::plan::OpType op>
struct CompareElementFunc { struct CompareElementFunc {
void void
@ -113,31 +99,32 @@ class PhyCompareFilterExpr : public Expr {
: Expr(DataType::BOOL, std::move(input), name), : Expr(DataType::BOOL, std::move(input), name),
left_field_(expr->left_field_id_), left_field_(expr->left_field_id_),
right_field_(expr->right_field_id_), right_field_(expr->right_field_id_),
segment_(segment), segment_chunk_reader_(segment, active_count),
active_count_(active_count),
batch_size_(batch_size), batch_size_(batch_size),
expr_(expr) { expr_(expr) {
is_left_indexed_ = segment_->HasIndex(left_field_); is_left_indexed_ = segment->HasIndex(left_field_);
is_right_indexed_ = segment_->HasIndex(right_field_); is_right_indexed_ = segment->HasIndex(right_field_);
size_per_chunk_ = segment_->size_per_chunk(); if (segment->is_chunked()) {
if (segment_->is_chunked()) {
left_num_chunk_ = left_num_chunk_ =
is_left_indexed_ is_left_indexed_
? segment_->num_chunk_index(expr_->left_field_id_) ? segment->num_chunk_index(expr_->left_field_id_)
: segment_->type() == SegmentType::Growing : segment->type() == SegmentType::Growing
? upper_div(active_count_, size_per_chunk_) ? upper_div(segment_chunk_reader_.active_count_,
: segment_->num_chunk_data(left_field_); segment_chunk_reader_.SizePerChunk())
: segment->num_chunk_data(left_field_);
right_num_chunk_ = right_num_chunk_ =
is_right_indexed_ is_right_indexed_
? segment_->num_chunk_index(expr_->right_field_id_) ? segment->num_chunk_index(expr_->right_field_id_)
: segment_->type() == SegmentType::Growing : segment->type() == SegmentType::Growing
? upper_div(active_count_, size_per_chunk_) ? upper_div(segment_chunk_reader_.active_count_,
: segment_->num_chunk_data(right_field_); segment_chunk_reader_.SizePerChunk())
: segment->num_chunk_data(right_field_);
num_chunk_ = left_num_chunk_; num_chunk_ = left_num_chunk_;
} else { } else {
num_chunk_ = is_left_indexed_ num_chunk_ = is_left_indexed_
? segment_->num_chunk_index(expr_->left_field_id_) ? segment->num_chunk_index(expr_->left_field_id_)
: upper_div(active_count_, size_per_chunk_); : upper_div(segment_chunk_reader_.active_count_,
segment_chunk_reader_.SizePerChunk());
} }
AssertInfo( AssertInfo(
@ -151,128 +138,60 @@ class PhyCompareFilterExpr : public Expr {
void void
MoveCursor() override { MoveCursor() override {
if (segment_->is_chunked()) { if (segment_chunk_reader_.segment_->is_chunked()) {
MoveCursorForMultipleChunk(); segment_chunk_reader_.MoveCursorForMultipleChunk(
left_current_chunk_id_,
left_current_chunk_pos_,
left_field_,
left_num_chunk_,
batch_size_);
segment_chunk_reader_.MoveCursorForMultipleChunk(
right_current_chunk_id_,
right_current_chunk_pos_,
right_field_,
right_num_chunk_,
batch_size_);
} else { } else {
MoveCursorForSingleChunk(); segment_chunk_reader_.MoveCursorForSingleChunk(
} current_chunk_id_, current_chunk_pos_, num_chunk_, batch_size_);
}
void
MoveCursorForMultipleChunk() {
int64_t processed_rows = 0;
for (int64_t chunk_id = left_current_chunk_id_;
chunk_id < left_num_chunk_;
++chunk_id) {
auto chunk_size = 0;
if (segment_->type() == SegmentType::Growing) {
chunk_size = chunk_id == left_num_chunk_ - 1
? active_count_ - chunk_id * size_per_chunk_
: size_per_chunk_;
} else {
chunk_size = segment_->chunk_size(left_field_, chunk_id);
}
for (int i = chunk_id == left_current_chunk_id_
? left_current_chunk_pos_
: 0;
i < chunk_size;
++i) {
if (++processed_rows >= batch_size_) {
left_current_chunk_id_ = chunk_id;
left_current_chunk_pos_ = i + 1;
}
}
}
processed_rows = 0;
for (int64_t chunk_id = right_current_chunk_id_;
chunk_id < right_num_chunk_;
++chunk_id) {
auto chunk_size = 0;
if (segment_->type() == SegmentType::Growing) {
chunk_size = chunk_id == right_num_chunk_ - 1
? active_count_ - chunk_id * size_per_chunk_
: size_per_chunk_;
} else {
chunk_size = segment_->chunk_size(right_field_, chunk_id);
}
for (int i = chunk_id == right_current_chunk_id_
? right_current_chunk_pos_
: 0;
i < chunk_size;
++i) {
if (++processed_rows >= batch_size_) {
right_current_chunk_id_ = chunk_id;
right_current_chunk_pos_ = i + 1;
}
}
}
}
void
MoveCursorForSingleChunk() {
int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
++chunk_id) {
auto chunk_size = chunk_id == num_chunk_ - 1
? active_count_ - chunk_id * size_per_chunk_
: size_per_chunk_;
for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
i < chunk_size;
++i) {
if (++processed_rows >= batch_size_) {
current_chunk_id_ = chunk_id;
current_chunk_pos_ = i + 1;
}
}
}
}
int64_t
GetCurrentRows() {
if (segment_->is_chunked()) {
auto current_rows =
is_left_indexed_ && segment_->type() == SegmentType::Sealed
? left_current_chunk_pos_
: segment_->num_rows_until_chunk(left_field_,
left_current_chunk_id_) +
left_current_chunk_pos_;
return current_rows;
} else {
return segment_->type() == SegmentType::Growing
? current_chunk_id_ * size_per_chunk_ +
current_chunk_pos_
: current_chunk_pos_;
} }
} }
private: private:
int64_t
GetCurrentRows() {
if (segment_chunk_reader_.segment_->is_chunked()) {
auto current_rows =
is_left_indexed_ && segment_chunk_reader_.segment_->type() ==
SegmentType::Sealed
? left_current_chunk_pos_
: segment_chunk_reader_.segment_->num_rows_until_chunk(
left_field_, left_current_chunk_id_) +
left_current_chunk_pos_;
return current_rows;
} else {
return segment_chunk_reader_.segment_->type() ==
SegmentType::Growing
? current_chunk_id_ *
segment_chunk_reader_.SizePerChunk() +
current_chunk_pos_
: current_chunk_pos_;
}
}
int64_t int64_t
GetNextBatchSize(); GetNextBatchSize();
bool bool
IsStringExpr(); IsStringExpr();
template <typename T>
MultipleChunkDataAccessor
GetChunkData(FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos);
template <typename T>
ChunkDataAccessor
GetChunkData(FieldId field_id, int chunk_id, int data_barrier);
template <typename T, typename U, typename FUNC, typename... ValTypes> template <typename T, typename U, typename FUNC, typename... ValTypes>
int64_t int64_t
ProcessBothDataChunks(FUNC func, ProcessBothDataChunks(FUNC func,
TargetBitmapView res, TargetBitmapView res,
TargetBitmapView valid_res, TargetBitmapView valid_res,
ValTypes... values) { ValTypes... values) {
if (segment_->is_chunked()) { if (segment_chunk_reader_.segment_->is_chunked()) {
return ProcessBothDataChunksForMultipleChunk<T, return ProcessBothDataChunksForMultipleChunk<T,
U, U,
FUNC, FUNC,
@ -292,18 +211,27 @@ class PhyCompareFilterExpr : public Expr {
ValTypes... values) { ValTypes... values) {
int64_t processed_size = 0; int64_t processed_size = 0;
const auto active_count = segment_chunk_reader_.active_count_;
for (size_t i = current_chunk_id_; i < num_chunk_; i++) { for (size_t i = current_chunk_id_; i < num_chunk_; i++) {
auto left_chunk = segment_->chunk_data<T>(left_field_, i); auto left_chunk =
auto right_chunk = segment_->chunk_data<U>(right_field_, i); segment_chunk_reader_.segment_->chunk_data<T>(left_field_, i);
auto right_chunk =
segment_chunk_reader_.segment_->chunk_data<U>(right_field_, i);
auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0; auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0;
auto size = auto size =
(i == (num_chunk_ - 1)) (i == (num_chunk_ - 1))
? (segment_->type() == SegmentType::Growing ? (segment_chunk_reader_.segment_->type() ==
? (active_count_ % size_per_chunk_ == 0 SegmentType::Growing
? size_per_chunk_ - data_pos ? (active_count % segment_chunk_reader_
: active_count_ % size_per_chunk_ - data_pos) .SizePerChunk() ==
: active_count_ - data_pos) 0
: size_per_chunk_ - data_pos; ? segment_chunk_reader_.SizePerChunk() -
data_pos
: active_count % segment_chunk_reader_
.SizePerChunk() -
data_pos)
: active_count - data_pos)
: segment_chunk_reader_.SizePerChunk() - data_pos;
if (processed_size + size >= batch_size_) { if (processed_size + size >= batch_size_) {
size = batch_size_ - processed_size; size = batch_size_ - processed_size;
@ -348,19 +276,29 @@ class PhyCompareFilterExpr : public Expr {
// only call this function when left and right are not indexed, so they have the same number of chunks // only call this function when left and right are not indexed, so they have the same number of chunks
for (size_t i = left_current_chunk_id_; i < left_num_chunk_; i++) { for (size_t i = left_current_chunk_id_; i < left_num_chunk_; i++) {
auto left_chunk = segment_->chunk_data<T>(left_field_, i); auto left_chunk =
auto right_chunk = segment_->chunk_data<U>(right_field_, i); segment_chunk_reader_.segment_->chunk_data<T>(left_field_, i);
auto right_chunk =
segment_chunk_reader_.segment_->chunk_data<U>(right_field_, i);
auto data_pos = auto data_pos =
(i == left_current_chunk_id_) ? left_current_chunk_pos_ : 0; (i == left_current_chunk_id_) ? left_current_chunk_pos_ : 0;
auto size = 0; auto size = 0;
if (segment_->type() == SegmentType::Growing) { if (segment_chunk_reader_.segment_->type() ==
size = (i == (left_num_chunk_ - 1)) SegmentType::Growing) {
? (active_count_ % size_per_chunk_ == 0 size =
? size_per_chunk_ - data_pos (i == (left_num_chunk_ - 1))
: active_count_ % size_per_chunk_ - data_pos) ? (segment_chunk_reader_.active_count_ %
: size_per_chunk_ - data_pos; segment_chunk_reader_.SizePerChunk() ==
0
? segment_chunk_reader_.SizePerChunk() - data_pos
: segment_chunk_reader_.active_count_ %
segment_chunk_reader_.SizePerChunk() -
data_pos)
: segment_chunk_reader_.SizePerChunk() - data_pos;
} else { } else {
size = segment_->chunk_size(left_field_, i) - data_pos; size =
segment_chunk_reader_.segment_->chunk_size(left_field_, i) -
data_pos;
} }
if (processed_size + size >= batch_size_) { if (processed_size + size >= batch_size_) {
@ -396,19 +334,6 @@ class PhyCompareFilterExpr : public Expr {
return processed_size; return processed_size;
} }
MultipleChunkDataAccessor
GetChunkData(DataType data_type,
FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos);
ChunkDataAccessor
GetChunkData(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier);
template <typename OpType> template <typename OpType>
VectorPtr VectorPtr
ExecCompareExprDispatcher(OpType op); ExecCompareExprDispatcher(OpType op);
@ -432,7 +357,6 @@ class PhyCompareFilterExpr : public Expr {
const FieldId right_field_; const FieldId right_field_;
bool is_left_indexed_; bool is_left_indexed_;
bool is_right_indexed_; bool is_right_indexed_;
int64_t active_count_{0};
int64_t num_chunk_{0}; int64_t num_chunk_{0};
int64_t left_num_chunk_{0}; int64_t left_num_chunk_{0};
int64_t right_num_chunk_{0}; int64_t right_num_chunk_{0};
@ -442,9 +366,8 @@ class PhyCompareFilterExpr : public Expr {
int64_t right_current_chunk_pos_{0}; int64_t right_current_chunk_pos_{0};
int64_t current_chunk_id_{0}; int64_t current_chunk_id_{0};
int64_t current_chunk_pos_{0}; int64_t current_chunk_pos_{0};
int64_t size_per_chunk_{0};
const segcore::SegmentInternalInterface* segment_; const segcore::SegmentChunkReader segment_chunk_reader_;
int64_t batch_size_; int64_t batch_size_;
std::shared_ptr<const milvus::expr::CompareExpr> expr_; std::shared_ptr<const milvus::expr::CompareExpr> expr_;
}; };

View File

@ -16,9 +16,12 @@
#include "Expr.h" #include "Expr.h"
#include "common/EasyAssert.h"
#include "exec/expression/AlwaysTrueExpr.h" #include "exec/expression/AlwaysTrueExpr.h"
#include "exec/expression/BinaryArithOpEvalRangeExpr.h" #include "exec/expression/BinaryArithOpEvalRangeExpr.h"
#include "exec/expression/BinaryRangeExpr.h" #include "exec/expression/BinaryRangeExpr.h"
#include "exec/expression/CallExpr.h"
#include "exec/expression/ColumnExpr.h"
#include "exec/expression/CompareExpr.h" #include "exec/expression/CompareExpr.h"
#include "exec/expression/ConjunctExpr.h" #include "exec/expression/ConjunctExpr.h"
#include "exec/expression/ExistsExpr.h" #include "exec/expression/ExistsExpr.h"
@ -27,6 +30,10 @@
#include "exec/expression/LogicalUnaryExpr.h" #include "exec/expression/LogicalUnaryExpr.h"
#include "exec/expression/TermExpr.h" #include "exec/expression/TermExpr.h"
#include "exec/expression/UnaryExpr.h" #include "exec/expression/UnaryExpr.h"
#include "exec/expression/ValueExpr.h"
#include <memory>
namespace milvus { namespace milvus {
namespace exec { namespace exec {
@ -156,8 +163,14 @@ CompileExpression(const expr::TypedExprPtr& expr,
}; };
auto input_types = GetTypes(compiled_inputs); auto input_types = GetTypes(compiled_inputs);
if (auto call = dynamic_cast<const expr::CallTypeExpr*>(expr.get())) { if (auto call = std::dynamic_pointer_cast<const expr::CallExpr>(expr)) {
// TODO: support function register and search mode result = std::make_shared<PhyCallExpr>(
compiled_inputs,
call,
"PhyCallExpr",
context->get_segment(),
context->get_active_count(),
context->query_config()->get_expr_batch_size());
} else if (auto casted_expr = std::dynamic_pointer_cast< } else if (auto casted_expr = std::dynamic_pointer_cast<
const milvus::expr::UnaryRangeFilterExpr>(expr)) { const milvus::expr::UnaryRangeFilterExpr>(expr)) {
result = std::make_shared<PhyUnaryRangeFilterExpr>( result = std::make_shared<PhyUnaryRangeFilterExpr>(
@ -251,6 +264,29 @@ CompileExpression(const expr::TypedExprPtr& expr,
context->get_segment(), context->get_segment(),
context->get_active_count(), context->get_active_count(),
context->query_config()->get_expr_batch_size()); context->query_config()->get_expr_batch_size());
} else if (auto value_expr =
std::dynamic_pointer_cast<const milvus::expr::ValueExpr>(
expr)) {
// used for function call arguments, may emit any type
result = std::make_shared<PhyValueExpr>(
compiled_inputs,
value_expr,
"PhyValueExpr",
context->get_segment(),
context->get_active_count(),
context->query_config()->get_expr_batch_size());
} else if (auto column_expr =
std::dynamic_pointer_cast<const milvus::expr::ColumnExpr>(
expr)) {
result = std::make_shared<PhyColumnExpr>(
compiled_inputs,
column_expr,
"PhyColumnExpr",
context->get_segment(),
context->get_active_count(),
context->query_config()->get_expr_batch_size());
} else {
PanicInfo(ExprInvalid, "unsupport expr: ", expr->ToString());
} }
return result; return result;
} }
@ -261,4 +297,4 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector<ExprPtr>& exprs) {
} }
} // namespace exec } // namespace exec
} // namespace milvus } // namespace milvus

View File

@ -77,6 +77,7 @@ class Expr {
DataType type_; DataType type_;
const std::vector<std::shared_ptr<Expr>> inputs_; const std::vector<std::shared_ptr<Expr>> inputs_;
std::string name_; std::string name_;
// NOTE: unused
std::shared_ptr<VectorFunction> vector_func_; std::shared_ptr<VectorFunction> vector_func_;
}; };
@ -84,6 +85,9 @@ using ExprPtr = std::shared_ptr<milvus::exec::Expr>;
using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int); using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int);
/*
* The expr has only one column.
*/
class SegmentExpr : public Expr { class SegmentExpr : public Expr {
public: public:
SegmentExpr(const std::vector<ExprPtr>&& input, SegmentExpr(const std::vector<ExprPtr>&& input,
@ -762,7 +766,8 @@ CompileExpression(const expr::TypedExprPtr& expr,
class ExprSet { class ExprSet {
public: public:
explicit ExprSet(const std::vector<expr::TypedExprPtr>& logical_exprs, explicit ExprSet(const std::vector<expr::TypedExprPtr>& logical_exprs,
ExecContext* exec_ctx) { ExecContext* exec_ctx)
: exec_ctx_(exec_ctx) {
exprs_ = CompileExpressions(logical_exprs, exec_ctx); exprs_ = CompileExpressions(logical_exprs, exec_ctx);
} }

View File

@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ValueExpr.h"
#include "common/Vector.h"
namespace milvus {
namespace exec {
void
PhyValueExpr::Eval(EvalCtx& context, VectorPtr& result) {
int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_
? active_count_ - current_pos_
: batch_size_;
if (real_batch_size == 0) {
result = nullptr;
return;
}
switch (expr_->type()) {
case DataType::NONE:
// null
result = std::make_shared<ConstantVector<bool>>(
expr_->type(), real_batch_size, false, 1);
break;
case DataType::BOOL:
result = std::make_shared<ConstantVector<bool>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().bool_val());
break;
case DataType::INT8:
result = std::make_shared<ConstantVector<int8_t>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().int64_val());
break;
case DataType::INT16:
result = std::make_shared<ConstantVector<int16_t>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().int64_val());
break;
case DataType::INT32:
result = std::make_shared<ConstantVector<int32_t>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().int64_val());
break;
case DataType::INT64:
result = std::make_shared<ConstantVector<int64_t>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().int64_val());
break;
case DataType::FLOAT:
result = std::make_shared<ConstantVector<float>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().float_val());
break;
case DataType::DOUBLE:
result = std::make_shared<ConstantVector<double>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().float_val());
break;
case DataType::STRING:
case DataType::VARCHAR:
result = std::make_shared<ConstantVector<std::string>>(
expr_->type(),
real_batch_size,
expr_->GetGenericValue().string_val());
break;
// TODO: json and array type
case DataType::ARRAY:
case DataType::JSON:
default:
PanicInfo(DataTypeInvalid,
"PhyValueExpr not support data type " +
GetDataTypeName(expr_->type()));
}
current_pos_ += real_batch_size;
}
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,67 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
class PhyValueExpr : public Expr {
public:
PhyValueExpr(const std::vector<std::shared_ptr<Expr>>& input,
const std::shared_ptr<const milvus::expr::ValueExpr> expr,
const std::string& name,
const segcore::SegmentInternalInterface* segment,
int64_t active_count,
int64_t batch_size)
: Expr(expr->type(), std::move(input), name),
expr_(expr),
active_count_(active_count),
batch_size_(batch_size) {
AssertInfo(input.empty(),
"PhyValueExpr should not have input, but got " +
std::to_string(input.size()));
}
void
Eval(EvalCtx& context, VectorPtr& result) override;
void
MoveCursor() override {
int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_
? active_count_ - current_pos_
: batch_size_;
current_pos_ += real_batch_size;
}
private:
std::shared_ptr<const milvus::expr::ValueExpr> expr_;
const int64_t active_count_;
int64_t current_pos_{0};
const int64_t batch_size_;
};
} //namespace exec
} // namespace milvus

View File

@ -0,0 +1,83 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "exec/expression/function/FunctionFactory.h"
#include <mutex>
#include "exec/expression/function/impl/StringFunctions.h"
#include "log/Log.h"
namespace milvus {
namespace exec {
namespace expression {
std::string
FilterFunctionRegisterKey::ToString() const {
std::ostringstream oss;
oss << func_name << "(";
for (size_t i = 0; i < func_param_type_list.size(); ++i) {
oss << GetDataTypeName(func_param_type_list[i]);
if (i < func_param_type_list.size() - 1) {
oss << ", ";
}
}
oss << ")";
return oss.str();
}
FunctionFactory&
FunctionFactory::Instance() {
static FunctionFactory factory;
return factory;
}
void
FunctionFactory::Initialize() {
std::call_once(init_flag_, &FunctionFactory::RegisterAllFunctions, this);
}
void
FunctionFactory::RegisterAllFunctions() {
RegisterFilterFunction(
"empty", {DataType::VARCHAR}, function::EmptyVarchar);
RegisterFilterFunction("starts_with",
{DataType::VARCHAR, DataType::VARCHAR},
function::StartsWithVarchar);
LOG_INFO("{} functions registered", GetFilterFunctionNum());
}
void
FunctionFactory::RegisterFilterFunction(
std::string func_name,
std::vector<DataType> func_param_type_list,
FilterFunctionPtr func) {
filter_function_map_[FilterFunctionRegisterKey{
func_name, func_param_type_list}] = func;
}
const FilterFunctionPtr
FunctionFactory::GetFilterFunction(
const FilterFunctionRegisterKey& func_sig) const {
auto iter = filter_function_map_.find(func_sig);
if (iter != filter_function_map_.end()) {
return iter->second;
}
return nullptr;
}
} // namespace expression
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,111 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <functional>
#include <mutex>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
#include <boost/variant.hpp>
#include "common/Vector.h"
namespace milvus {
namespace exec {
class EvalCtx;
class Expr;
class PhyCallExpr;
namespace expression {
struct FilterFunctionRegisterKey {
std::string func_name;
std::vector<DataType> func_param_type_list;
std::string
ToString() const;
bool
operator==(const FilterFunctionRegisterKey& other) const {
return func_name == other.func_name &&
func_param_type_list == other.func_param_type_list;
}
struct Hash {
size_t
operator()(const FilterFunctionRegisterKey& s) const {
size_t h1 = std::hash<std::string_view>{}(s.func_name);
size_t h2 = boost::hash_range(s.func_param_type_list.begin(),
s.func_param_type_list.end());
return h1 ^ h2;
}
};
};
using FilterFunctionParameter = std::shared_ptr<Expr>;
using FilterFunctionReturn = VectorPtr;
using FilterFunctionPtr = void (*)(const RowVector& args,
FilterFunctionReturn& result);
class FunctionFactory {
public:
static FunctionFactory&
Instance();
void
Initialize();
void
RegisterFilterFunction(std::string func_name,
std::vector<DataType> func_param_type_list,
FilterFunctionPtr func);
const FilterFunctionPtr
GetFilterFunction(const FilterFunctionRegisterKey& func_sig) const;
size_t
GetFilterFunctionNum() const {
return filter_function_map_.size();
}
std::vector<FilterFunctionRegisterKey>
ListAllFilterFunctions() const {
std::vector<FilterFunctionRegisterKey> result;
for (const auto& [key, value] : filter_function_map_) {
result.push_back(key);
}
return result;
}
private:
void
RegisterAllFunctions();
std::unordered_map<FilterFunctionRegisterKey,
FilterFunctionPtr,
FilterFunctionRegisterKey::Hash>
filter_function_map_;
std::once_flag init_flag_;
};
} // namespace expression
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,30 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "exec/expression/function/FunctionImplUtils.h"
#include "common/EasyAssert.h"
namespace milvus::exec::expression::function {
void
CheckVarcharOrStringType(std::shared_ptr<SimpleVector>& vec) {
if (vec->type() != DataType::VARCHAR && vec->type() != DataType::STRING) {
PanicInfo(ExprInvalid,
"invalid argument type, expect VARCHAR or STRING, actual {}",
vec->type());
}
}
} // namespace milvus::exec::expression::function

View File

@ -0,0 +1,25 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common/Vector.h"
namespace milvus::exec::expression::function {
void
CheckVarcharOrStringType(std::shared_ptr<SimpleVector>& vec);
} // namespace milvus::exec::expression::function

View File

@ -0,0 +1,59 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "exec/expression/function/FunctionImplUtils.h"
#include "exec/expression/function/impl/StringFunctions.h"
#include <boost/variant/get.hpp>
#include <string>
#include "common/EasyAssert.h"
#include "exec/expression/function/FunctionFactory.h"
namespace milvus {
namespace exec {
namespace expression {
namespace function {
void
EmptyVarchar(const RowVector& args, FilterFunctionReturn& result) {
if (args.childrens().size() != 1) {
PanicInfo(ExprInvalid,
"invalid argument count, expect 1, actual {}",
args.childrens().size());
}
auto arg = args.child(0);
auto vec = std::dynamic_pointer_cast<SimpleVector>(arg);
Assert(vec != nullptr);
CheckVarcharOrStringType(vec);
TargetBitmap bitmap(vec->size(), false);
TargetBitmap valid_bitmap(vec->size(), true);
for (size_t i = 0; i < vec->size(); ++i) {
if (vec->ValidAt(i)) {
bitmap[i] = reinterpret_cast<std::string*>(
vec->RawValueAt(i, sizeof(std::string)))
->empty();
} else {
valid_bitmap[i] = false;
}
}
result = std::make_shared<ColumnVector>(std::move(bitmap),
std::move(valid_bitmap));
}
} // namespace function
} // namespace expression
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,64 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "exec/expression/function/FunctionImplUtils.h"
#include "exec/expression/function/impl/StringFunctions.h"
#include <boost/variant/get.hpp>
#include <string>
#include "common/EasyAssert.h"
#include "exec/expression/function/FunctionFactory.h"
namespace milvus {
namespace exec {
namespace expression {
namespace function {
void
StartsWithVarchar(const RowVector& args, FilterFunctionReturn& result) {
if (args.childrens().size() != 2) {
PanicInfo(ExprInvalid,
"invalid argument count, expect 2, actual {}",
args.childrens().size());
}
auto strs = std::dynamic_pointer_cast<SimpleVector>(args.child(0));
Assert(strs != nullptr);
CheckVarcharOrStringType(strs);
auto prefixes = std::dynamic_pointer_cast<SimpleVector>(args.child(1));
Assert(prefixes != nullptr);
CheckVarcharOrStringType(prefixes);
TargetBitmap bitmap(strs->size(), false);
TargetBitmap valid_bitmap(strs->size(), true);
for (size_t i = 0; i < strs->size(); ++i) {
if (strs->ValidAt(i) && prefixes->ValidAt(i)) {
auto* str_ptr = reinterpret_cast<std::string*>(
strs->RawValueAt(i, sizeof(std::string)));
auto* prefix_ptr = reinterpret_cast<std::string*>(
prefixes->RawValueAt(i, sizeof(std::string)));
bitmap.set(i, str_ptr->find(*prefix_ptr) == 0);
} else {
valid_bitmap[i] = false;
}
}
result = std::make_shared<ColumnVector>(std::move(bitmap),
std::move(valid_bitmap));
}
} // namespace function
} // namespace expression
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,35 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common/Vector.h"
#include "exec/expression/function/FunctionFactory.h"
namespace milvus {
namespace exec {
namespace expression {
namespace function {
void
EmptyVarchar(const RowVector& args, FilterFunctionReturn& result);
void
StartsWithVarchar(const RowVector& args, FilterFunctionReturn& result);
} // namespace function
} // namespace expression
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,23 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "exec/expression/function/init_c.h"
#include "exec/expression/function/FunctionFactory.h"
void
InitExecExpressionFunctionFactory() {
milvus::exec::expression::FunctionFactory::Instance().Initialize();
}

View File

@ -0,0 +1,28 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
void
InitExecExpressionFunctionFactory();
#ifdef __cplusplus
};
#endif

View File

@ -76,13 +76,24 @@ PhyFilterBitsNode::GetOutput() {
"PhyFilterBitsNode result size should be size one and not " "PhyFilterBitsNode result size should be size one and not "
"be nullptr"); "be nullptr");
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results_[0]); if (auto col_vec =
auto col_vec_size = col_vec->size(); std::dynamic_pointer_cast<ColumnVector>(results_[0])) {
TargetBitmapView view(col_vec->GetRawData(), col_vec_size); if (col_vec->IsBitmap()) {
bitset.append(view); auto col_vec_size = col_vec->size();
TargetBitmapView valid_view(col_vec->GetValidRawData(), col_vec_size); TargetBitmapView view(col_vec->GetRawData(), col_vec_size);
valid_bitset.append(valid_view); bitset.append(view);
num_processed_rows_ += col_vec_size; TargetBitmapView valid_view(col_vec->GetValidRawData(),
col_vec_size);
valid_bitset.append(valid_view);
num_processed_rows_ += col_vec_size;
} else {
PanicInfo(ExprInvalid,
"PhyFilterBitsNode result should be bitmap");
}
} else {
PanicInfo(ExprInvalid,
"PhyFilterBitsNode result should be ColumnVector");
}
} }
bitset.flip(); bitset.flip();
Assert(bitset.size() == need_process_rows_); Assert(bitset.size() == need_process_rows_);
@ -102,4 +113,4 @@ PhyFilterBitsNode::GetOutput() {
} }
} // namespace exec } // namespace exec
} // namespace milvus } // namespace milvus

View File

@ -21,6 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "exec/expression/function/FunctionFactory.h"
#include "common/Exception.h" #include "common/Exception.h"
#include "common/Schema.h" #include "common/Schema.h"
#include "common/Types.h" #include "common/Types.h"
@ -211,6 +212,7 @@ class ITypeExpr {
using TypedExprPtr = std::shared_ptr<const ITypeExpr>; using TypedExprPtr = std::shared_ptr<const ITypeExpr>;
// NOTE: unused
class InputTypeExpr : public ITypeExpr { class InputTypeExpr : public ITypeExpr {
public: public:
InputTypeExpr(DataType type) : ITypeExpr(type) { InputTypeExpr(DataType type) : ITypeExpr(type) {
@ -224,42 +226,7 @@ class InputTypeExpr : public ITypeExpr {
using InputTypeExprPtr = std::shared_ptr<const InputTypeExpr>; using InputTypeExprPtr = std::shared_ptr<const InputTypeExpr>;
class CallTypeExpr : public ITypeExpr { // NOTE: unused
public:
CallTypeExpr(DataType type,
const std::vector<TypedExprPtr>& inputs,
std::string fun_name)
: ITypeExpr{type, std::move(inputs)} {
}
virtual ~CallTypeExpr() = default;
virtual const std::string&
name() const {
return name_;
}
std::string
ToString() const override {
std::string str{};
str += name();
str += "(";
for (size_t i = 0; i < inputs_.size(); ++i) {
if (i != 0) {
str += ",";
}
str += inputs_[i]->ToString();
}
str += ")";
return str;
}
private:
std::string name_;
};
using CallTypeExprPtr = std::shared_ptr<const CallTypeExpr>;
class FieldAccessTypeExpr : public ITypeExpr { class FieldAccessTypeExpr : public ITypeExpr {
public: public:
FieldAccessTypeExpr(DataType type, const std::string& name) FieldAccessTypeExpr(DataType type, const std::string& name)
@ -311,6 +278,71 @@ class ITypeFilterExpr : public ITypeExpr {
virtual ~ITypeFilterExpr() = default; virtual ~ITypeFilterExpr() = default;
}; };
class ColumnExpr : public ITypeExpr {
public:
explicit ColumnExpr(const ColumnInfo& column)
: ITypeExpr(column.data_type_), column_(column) {
}
const ColumnInfo&
GetColumn() const {
return column_;
}
std::string
ToString() const override {
std::stringstream ss;
ss << "ColumnExpr: {columnInfo:" << column_.ToString() << "}";
return ss.str();
}
private:
const ColumnInfo column_;
};
class ValueExpr : public ITypeExpr {
public:
explicit ValueExpr(const proto::plan::GenericValue& val)
: ITypeExpr(DataType::NONE), val_(val) {
switch (val.val_case()) {
case proto::plan::GenericValue::ValCase::kBoolVal:
type_ = DataType::BOOL;
break;
case proto::plan::GenericValue::ValCase::kInt64Val:
type_ = DataType::INT64;
break;
case proto::plan::GenericValue::ValCase::kFloatVal:
type_ = DataType::FLOAT;
break;
case proto::plan::GenericValue::ValCase::kStringVal:
type_ = DataType::VARCHAR;
break;
case proto::plan::GenericValue::ValCase::kArrayVal:
type_ = DataType::ARRAY;
break;
case proto::plan::GenericValue::ValCase::VAL_NOT_SET:
type_ = DataType::NONE;
break;
}
}
std::string
ToString() const override {
std::stringstream ss;
ss << "ValueExpr: {"
<< " val:" << val_.DebugString() << "}";
return ss.str();
}
const proto::plan::GenericValue
GetGenericValue() const {
return val_;
}
private:
const proto::plan::GenericValue val_;
};
class UnaryRangeFilterExpr : public ITypeFilterExpr { class UnaryRangeFilterExpr : public ITypeFilterExpr {
public: public:
explicit UnaryRangeFilterExpr(const ColumnInfo& column, explicit UnaryRangeFilterExpr(const ColumnInfo& column,
@ -595,6 +627,46 @@ class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr {
const proto::plan::GenericValue value_; const proto::plan::GenericValue value_;
}; };
class CallExpr : public ITypeFilterExpr {
public:
CallExpr(const std::string fun_name,
const std::vector<TypedExprPtr>& parameters,
const exec::expression::FilterFunctionPtr function_ptr)
: fun_name_(std::move(fun_name)), function_ptr_(function_ptr) {
inputs_.insert(inputs_.end(), parameters.begin(), parameters.end());
}
virtual ~CallExpr() = default;
const std::string&
fun_name() const {
return fun_name_;
}
const exec::expression::FilterFunctionPtr
function_ptr() const {
return function_ptr_;
}
std::string
ToString() const override {
std::string parameters;
for (auto& e : inputs_) {
parameters += e->ToString();
parameters += ", ";
}
return fmt::format("CallExpr:[Function Name: {}, Parameters: {}]",
fun_name_,
parameters);
}
private:
const std::string fun_name_;
const exec::expression::FilterFunctionPtr function_ptr_;
};
using CallExprPtr = std::shared_ptr<const CallExpr>;
class CompareExpr : public ITypeFilterExpr { class CompareExpr : public ITypeFilterExpr {
public: public:
CompareExpr(const FieldId& left_field, CompareExpr(const FieldId& left_field,

View File

@ -15,9 +15,11 @@
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <vector>
#include "common/VectorTrait.h" #include "common/VectorTrait.h"
#include "common/EasyAssert.h" #include "common/EasyAssert.h"
#include "exec/expression/function/FunctionFactory.h"
#include "pb/plan.pb.h" #include "pb/plan.pb.h"
#include "query/Utils.h" #include "query/Utils.h"
#include "knowhere/comp/materialized_view.h" #include "knowhere/comp/materialized_view.h"
@ -256,6 +258,29 @@ ProtoParser::ParseBinaryRangeExprs(
expr_pb.upper_inclusive()); expr_pb.upper_inclusive());
} }
expr::TypedExprPtr
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
std::vector<expr::TypedExprPtr> parameters;
std::vector<DataType> func_param_type_list;
for (auto& param_expr : expr_pb.function_parameters()) {
// function parameter can be any type
auto e = this->ParseExprs(param_expr, TypeIsAny);
parameters.push_back(e);
func_param_type_list.push_back(e->type());
}
auto& factory = exec::expression::FunctionFactory::Instance();
exec::expression::FilterFunctionRegisterKey func_sig{
expr_pb.function_name(), std::move(func_param_type_list)};
auto function = factory.GetFilterFunction(func_sig);
if (function == nullptr) {
PanicInfo(ExprInvalid,
"function " + func_sig.ToString() + " not found. ");
}
return std::make_shared<expr::CallExpr>(
expr_pb.function_name(), parameters, function);
}
expr::TypedExprPtr expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
auto& left_column_info = expr_pb.left_column_info(); auto& left_column_info = expr_pb.left_column_info();
@ -349,45 +374,80 @@ ProtoParser::ParseJsonContainsExprs(
std::move(values)); std::move(values));
} }
expr::TypedExprPtr
ProtoParser::ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb) {
return std::make_shared<expr::ColumnExpr>(expr_pb.info());
}
expr::TypedExprPtr
ProtoParser::ParseValueExprs(const proto::plan::ValueExpr& expr_pb) {
return std::make_shared<expr::ValueExpr>(expr_pb.value());
}
expr::TypedExprPtr expr::TypedExprPtr
ProtoParser::CreateAlwaysTrueExprs() { ProtoParser::CreateAlwaysTrueExprs() {
return std::make_shared<expr::AlwaysTrueExpr>(); return std::make_shared<expr::AlwaysTrueExpr>();
} }
expr::TypedExprPtr expr::TypedExprPtr
ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
TypeCheckFunction type_check) {
using ppe = proto::plan::Expr; using ppe = proto::plan::Expr;
expr::TypedExprPtr result;
switch (expr_pb.expr_case()) { switch (expr_pb.expr_case()) {
case ppe::kUnaryRangeExpr: { case ppe::kUnaryRangeExpr: {
return ParseUnaryRangeExprs(expr_pb.unary_range_expr()); result = ParseUnaryRangeExprs(expr_pb.unary_range_expr());
break;
} }
case ppe::kBinaryExpr: { case ppe::kBinaryExpr: {
return ParseBinaryExprs(expr_pb.binary_expr()); result = ParseBinaryExprs(expr_pb.binary_expr());
break;
} }
case ppe::kUnaryExpr: { case ppe::kUnaryExpr: {
return ParseUnaryExprs(expr_pb.unary_expr()); result = ParseUnaryExprs(expr_pb.unary_expr());
break;
} }
case ppe::kTermExpr: { case ppe::kTermExpr: {
return ParseTermExprs(expr_pb.term_expr()); result = ParseTermExprs(expr_pb.term_expr());
break;
} }
case ppe::kBinaryRangeExpr: { case ppe::kBinaryRangeExpr: {
return ParseBinaryRangeExprs(expr_pb.binary_range_expr()); result = ParseBinaryRangeExprs(expr_pb.binary_range_expr());
break;
} }
case ppe::kCompareExpr: { case ppe::kCompareExpr: {
return ParseCompareExprs(expr_pb.compare_expr()); result = ParseCompareExprs(expr_pb.compare_expr());
break;
} }
case ppe::kBinaryArithOpEvalRangeExpr: { case ppe::kBinaryArithOpEvalRangeExpr: {
return ParseBinaryArithOpEvalRangeExprs( result = ParseBinaryArithOpEvalRangeExprs(
expr_pb.binary_arith_op_eval_range_expr()); expr_pb.binary_arith_op_eval_range_expr());
break;
} }
case ppe::kExistsExpr: { case ppe::kExistsExpr: {
return ParseExistExprs(expr_pb.exists_expr()); result = ParseExistExprs(expr_pb.exists_expr());
break;
} }
case ppe::kAlwaysTrueExpr: { case ppe::kAlwaysTrueExpr: {
return CreateAlwaysTrueExprs(); result = CreateAlwaysTrueExprs();
break;
} }
case ppe::kJsonContainsExpr: { case ppe::kJsonContainsExpr: {
return ParseJsonContainsExprs(expr_pb.json_contains_expr()); result = ParseJsonContainsExprs(expr_pb.json_contains_expr());
break;
}
case ppe::kCallExpr: {
result = ParseCallExprs(expr_pb.call_expr());
break;
}
// may emit various types
case ppe::kColumnExpr: {
result = ParseColumnExprs(expr_pb.column_expr());
break;
}
case ppe::kValueExpr: {
result = ParseValueExprs(expr_pb.value_expr());
break;
} }
default: { default: {
std::string s; std::string s;
@ -396,6 +456,11 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) {
std::string("unsupported expr proto node: ") + s); std::string("unsupported expr proto node: ") + s);
} }
} }
if (type_check(result->type())) {
return result;
}
PanicInfo(
ExprInvalid, "expr type check failed, actual type: {}", result->type());
} }
} // namespace milvus::query } // namespace milvus::query

View File

@ -23,6 +23,17 @@
namespace milvus::query { namespace milvus::query {
class ProtoParser { class ProtoParser {
public:
using TypeCheckFunction = std::function<bool(const DataType)>;
static bool
TypeIsBool(const DataType type) {
return type == DataType::BOOL;
}
static bool
TypeIsAny(const DataType) {
return true;
}
public: public:
explicit ProtoParser(const Schema& schema) : schema(schema) { explicit ProtoParser(const Schema& schema) : schema(schema) {
} }
@ -40,10 +51,15 @@ class ProtoParser {
CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto); CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto);
expr::TypedExprPtr expr::TypedExprPtr
ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); ParseExprs(const proto::plan::Expr& expr_pb,
TypeCheckFunction type_check = TypeIsBool);
private:
expr::TypedExprPtr
CreateAlwaysTrueExprs();
expr::TypedExprPtr expr::TypedExprPtr
ParseExprs(const proto::plan::Expr& expr_pb); ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb);
expr::TypedExprPtr expr::TypedExprPtr
ParseBinaryArithOpEvalRangeExprs( ParseBinaryArithOpEvalRangeExprs(
@ -52,18 +68,15 @@ class ProtoParser {
expr::TypedExprPtr expr::TypedExprPtr
ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb); ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb);
expr::TypedExprPtr
ParseCallExprs(const proto::plan::CallExpr& expr_pb);
expr::TypedExprPtr
ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb);
expr::TypedExprPtr expr::TypedExprPtr
ParseCompareExprs(const proto::plan::CompareExpr& expr_pb); ParseCompareExprs(const proto::plan::CompareExpr& expr_pb);
expr::TypedExprPtr
ParseTermExprs(const proto::plan::TermExpr& expr_pb);
expr::TypedExprPtr
ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb);
expr::TypedExprPtr
ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb);
expr::TypedExprPtr expr::TypedExprPtr
ParseExistExprs(const proto::plan::ExistsExpr& expr_pb); ParseExistExprs(const proto::plan::ExistsExpr& expr_pb);
@ -71,14 +84,23 @@ class ProtoParser {
ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb); ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb);
expr::TypedExprPtr expr::TypedExprPtr
CreateAlwaysTrueExprs(); ParseTermExprs(const proto::plan::TermExpr& expr_pb);
expr::TypedExprPtr
ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb);
expr::TypedExprPtr
ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb);
expr::TypedExprPtr
ParseValueExprs(const proto::plan::ValueExpr& expr_pb);
private: private:
const Schema& schema; const Schema& schema;
}; };
} // namespace milvus::query } // namespace milvus::query
//
template <> template <>
struct fmt::formatter<milvus::proto::plan::GenericValue::ValCase> struct fmt::formatter<milvus::proto::plan::GenericValue::ValCase>
: formatter<string_view> { : formatter<string_view> {

View File

@ -0,0 +1,327 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "segcore/SegmentChunkReader.h"
namespace milvus::segcore {
template <typename T>
MultipleChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor(FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) const {
if (index) {
auto& indexing = const_cast<index::ScalarIndex<T>&>(
segment_->chunk_scalar_index<T>(field_id, current_chunk_id));
auto current_chunk_size = segment_->type() == SegmentType::Growing
? SizePerChunk()
: active_count_;
if (indexing.HasRawData()) {
return [&, current_chunk_size]() -> const data_access_type {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
indexing = const_cast<index::ScalarIndex<T>&>(
segment_->chunk_scalar_index<T>(field_id,
current_chunk_id));
}
auto raw = indexing.Reverse_Lookup(current_chunk_pos);
current_chunk_pos++;
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
auto chunk_data =
segment_->chunk_data<T>(field_id, current_chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<T>(field_id, current_chunk_id).valid_data();
auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id);
return [=,
&current_chunk_id,
&current_chunk_pos]() mutable -> const data_access_type {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data =
segment_->chunk_data<T>(field_id, current_chunk_id).data();
chunk_valid_data =
segment_->chunk_data<T>(field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return chunk_data[current_chunk_pos++];
};
}
template <>
MultipleChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor<std::string>(
FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) const {
if (index) {
auto& indexing = const_cast<index::ScalarIndex<std::string>&>(
segment_->chunk_scalar_index<std::string>(field_id,
current_chunk_id));
auto current_chunk_size = segment_->type() == SegmentType::Growing
? SizePerChunk()
: active_count_;
if (indexing.HasRawData()) {
return [&, current_chunk_size]() mutable -> const data_access_type {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
indexing = const_cast<index::ScalarIndex<std::string>&>(
segment_->chunk_scalar_index<std::string>(
field_id, current_chunk_id));
}
auto raw = indexing.Reverse_Lookup(current_chunk_pos);
current_chunk_pos++;
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
if (segment_->type() == SegmentType::Growing &&
!storage::MmapManager::GetInstance()
.GetMmapConfig()
.growing_enable_mmap) {
auto chunk_data =
segment_->chunk_data<std::string>(field_id, current_chunk_id)
.data();
auto chunk_valid_data =
segment_->chunk_data<std::string>(field_id, current_chunk_id)
.valid_data();
auto current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
return [=,
&current_chunk_id,
&current_chunk_pos]() mutable -> const data_access_type {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data =
segment_
->chunk_data<std::string>(field_id, current_chunk_id)
.data();
chunk_valid_data =
segment_
->chunk_data<std::string>(field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return chunk_data[current_chunk_pos++];
};
} else {
auto chunk_data =
segment_->chunk_view<std::string_view>(field_id, current_chunk_id)
.first.data();
auto chunk_valid_data =
segment_->chunk_data<std::string_view>(field_id, current_chunk_id)
.valid_data();
auto current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
return [=,
&current_chunk_id,
&current_chunk_pos]() mutable -> const data_access_type {
if (current_chunk_pos >= current_chunk_size) {
current_chunk_id++;
current_chunk_pos = 0;
chunk_data = segment_
->chunk_view<std::string_view>(
field_id, current_chunk_id)
.first.data();
chunk_valid_data = segment_
->chunk_data<std::string_view>(
field_id, current_chunk_id)
.valid_data();
current_chunk_size =
segment_->chunk_size(field_id, current_chunk_id);
}
if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) {
current_chunk_pos++;
return std::nullopt;
}
return std::string(chunk_data[current_chunk_pos++]);
};
}
}
MultipleChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor(DataType data_type,
FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) const {
switch (data_type) {
case DataType::BOOL:
return GetChunkDataAccessor<bool>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT8:
return GetChunkDataAccessor<int8_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT16:
return GetChunkDataAccessor<int16_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT32:
return GetChunkDataAccessor<int32_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::INT64:
return GetChunkDataAccessor<int64_t>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::FLOAT:
return GetChunkDataAccessor<float>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::DOUBLE:
return GetChunkDataAccessor<double>(
field_id, index, current_chunk_id, current_chunk_pos);
case DataType::VARCHAR: {
return GetChunkDataAccessor<std::string>(
field_id, index, current_chunk_id, current_chunk_pos);
}
default:
PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
}
}
template <typename T>
ChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor(FieldId field_id,
int chunk_id,
int data_barrier) const {
if (chunk_id >= data_barrier) {
auto& indexing = segment_->chunk_scalar_index<T>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const data_access_type {
auto raw = indexing.Reverse_Lookup(i);
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
auto chunk_data = segment_->chunk_data<T>(field_id, chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<T>(field_id, chunk_id).valid_data();
return [chunk_data, chunk_valid_data](int i) -> const data_access_type {
if (chunk_valid_data && !chunk_valid_data[i]) {
return std::nullopt;
}
return chunk_data[i];
};
}
template <>
ChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor<std::string>(FieldId field_id,
int chunk_id,
int data_barrier) const {
if (chunk_id >= data_barrier) {
auto& indexing =
segment_->chunk_scalar_index<std::string>(field_id, chunk_id);
if (indexing.HasRawData()) {
return [&indexing](int i) -> const data_access_type {
auto raw = indexing.Reverse_Lookup(i);
if (!raw.has_value()) {
return std::nullopt;
}
return raw.value();
};
}
}
if (segment_->type() == SegmentType::Growing &&
!storage::MmapManager::GetInstance()
.GetMmapConfig()
.growing_enable_mmap) {
auto chunk_data =
segment_->chunk_data<std::string>(field_id, chunk_id).data();
auto chunk_valid_data =
segment_->chunk_data<std::string>(field_id, chunk_id).valid_data();
return [chunk_data, chunk_valid_data](int i) -> const data_access_type {
if (chunk_valid_data && !chunk_valid_data[i]) {
return std::nullopt;
}
return chunk_data[i];
};
} else {
auto chunk_info =
segment_->chunk_view<std::string_view>(field_id, chunk_id);
return [chunk_data = std::move(chunk_info.first),
chunk_valid_data = std::move(chunk_info.second)](
int i) -> const data_access_type {
if (i < chunk_valid_data.size() && !chunk_valid_data[i]) {
return std::nullopt;
}
return std::string(chunk_data[i]);
};
}
}
ChunkDataAccessor
SegmentChunkReader::GetChunkDataAccessor(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier) const {
switch (data_type) {
case DataType::BOOL:
return GetChunkDataAccessor<bool>(field_id, chunk_id, data_barrier);
case DataType::INT8:
return GetChunkDataAccessor<int8_t>(
field_id, chunk_id, data_barrier);
case DataType::INT16:
return GetChunkDataAccessor<int16_t>(
field_id, chunk_id, data_barrier);
case DataType::INT32:
return GetChunkDataAccessor<int32_t>(
field_id, chunk_id, data_barrier);
case DataType::INT64:
return GetChunkDataAccessor<int64_t>(
field_id, chunk_id, data_barrier);
case DataType::FLOAT:
return GetChunkDataAccessor<float>(
field_id, chunk_id, data_barrier);
case DataType::DOUBLE:
return GetChunkDataAccessor<double>(
field_id, chunk_id, data_barrier);
case DataType::VARCHAR: {
return GetChunkDataAccessor<std::string>(
field_id, chunk_id, data_barrier);
}
default:
PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
}
}
} // namespace milvus::segcore

View File

@ -0,0 +1,141 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include <boost/variant.hpp>
#include <optional>
#include "common/Types.h"
#include "segcore/SegmentInterface.h"
namespace milvus::segcore {
using data_access_type = std::optional<boost::variant<bool,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
std::string>>;
using ChunkDataAccessor = std::function<const data_access_type(int)>;
using MultipleChunkDataAccessor = std::function<const data_access_type()>;
class SegmentChunkReader {
public:
SegmentChunkReader(const segcore::SegmentInternalInterface* segment,
int64_t active_count)
: segment_(segment),
active_count_(active_count),
size_per_chunk_(segment->size_per_chunk()) {
}
MultipleChunkDataAccessor
GetChunkDataAccessor(DataType data_type,
FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) const;
ChunkDataAccessor
GetChunkDataAccessor(DataType data_type,
FieldId field_id,
int chunk_id,
int data_barrier) const;
void
MoveCursorForMultipleChunk(int64_t& current_chunk_id,
int64_t& current_chunk_pos,
const FieldId field_id,
const int64_t num_chunk,
const int64_t batch_size) const {
int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id; chunk_id < num_chunk;
++chunk_id) {
int64_t chunk_size = 0;
if (segment_->type() == SegmentType::Growing) {
const auto size_per_chunk = SizePerChunk();
chunk_size = chunk_id == num_chunk - 1
? active_count_ - chunk_id * size_per_chunk
: size_per_chunk;
} else {
chunk_size = segment_->chunk_size(field_id, chunk_id);
}
for (int64_t i = chunk_id == current_chunk_id ? current_chunk_pos
: 0;
i < chunk_size;
++i) {
if (++processed_rows >= batch_size) {
current_chunk_id = chunk_id;
current_chunk_pos = i + 1;
}
}
}
}
void
MoveCursorForSingleChunk(int64_t& current_chunk_id,
int64_t& current_chunk_pos,
const int64_t num_chunk,
const int64_t batch_size) const {
int64_t processed_rows = 0;
for (int64_t chunk_id = current_chunk_id; chunk_id < num_chunk;
++chunk_id) {
auto chunk_size = chunk_id == num_chunk - 1
? active_count_ - chunk_id * SizePerChunk()
: SizePerChunk();
for (int64_t i = chunk_id == current_chunk_id ? current_chunk_pos
: 0;
i < chunk_size;
++i) {
if (++processed_rows >= batch_size) {
current_chunk_id = chunk_id;
current_chunk_pos = i + 1;
}
}
}
}
int64_t
SizePerChunk() const {
return size_per_chunk_;
}
const int64_t active_count_;
const segcore::SegmentInternalInterface* segment_;
private:
template <typename T>
MultipleChunkDataAccessor
GetChunkDataAccessor(FieldId field_id,
bool index,
int64_t& current_chunk_id,
int64_t& current_chunk_pos) const;
template <typename T>
ChunkDataAccessor
GetChunkDataAccessor(FieldId field_id,
int chunk_id,
int data_barrier) const;
const int64_t size_per_chunk_;
};
} // namespace milvus::segcore

View File

@ -149,9 +149,7 @@ class SegmentInternalInterface : public SegmentInterface {
template <typename ViewType> template <typename ViewType>
std::pair<std::vector<ViewType>, FixedVector<bool>> std::pair<std::vector<ViewType>, FixedVector<bool>>
chunk_view(FieldId field_id, int64_t chunk_id) const { chunk_view(FieldId field_id, int64_t chunk_id) const {
auto chunk_info = chunk_view_impl(field_id, chunk_id); auto [string_views, valid_data] = chunk_view_impl(field_id, chunk_id);
auto string_views = chunk_info.first;
auto valid_data = chunk_info.second;
if constexpr (std::is_same_v<ViewType, std::string_view>) { if constexpr (std::is_same_v<ViewType, std::string_view>) {
return std::make_pair(std::move(string_views), return std::make_pair(std::move(string_views),
std::move(valid_data)); std::move(valid_data));

View File

@ -48,6 +48,7 @@ set(MILVUS_TEST_FILES
test_expr.cpp test_expr.cpp
test_expr_materialized_view.cpp test_expr_materialized_view.cpp
test_float16.cpp test_float16.cpp
test_function.cpp
test_futures.cpp test_futures.cpp
test_group_by.cpp test_group_by.cpp
test_growing.cpp test_growing.cpp

View File

@ -27,6 +27,7 @@
#include "exec/QueryContext.h" #include "exec/QueryContext.h"
#include "expr/ITypeExpr.h" #include "expr/ITypeExpr.h"
#include "exec/expression/Expr.h" #include "exec/expression/Expr.h"
#include "exec/expression/function/FunctionFactory.h"
using namespace milvus; using namespace milvus;
using namespace milvus::exec; using namespace milvus::exec;
@ -40,6 +41,10 @@ class TaskTest : public testing::TestWithParam<DataType> {
using namespace milvus; using namespace milvus;
using namespace milvus::query; using namespace milvus::query;
using namespace milvus::segcore; using namespace milvus::segcore;
milvus::exec::expression::FunctionFactory& factory =
milvus::exec::expression::FunctionFactory::Instance();
factory.Initialize();
auto schema = std::make_shared<Schema>(); auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField( auto vec_fid = schema->AddDebugField(
"fakevec", GetParam(), 16, knowhere::metric::L2); "fakevec", GetParam(), 16, knowhere::metric::L2);
@ -113,6 +118,62 @@ INSTANTIATE_TEST_SUITE_P(TaskTestSuite,
::testing::Values(DataType::VECTOR_FLOAT, ::testing::Values(DataType::VECTOR_FLOAT,
DataType::VECTOR_SPARSE_FLOAT)); DataType::VECTOR_SPARSE_FLOAT));
TEST_P(TaskTest, RegisterFunction) {
milvus::exec::expression::FunctionFactory& factory =
milvus::exec::expression::FunctionFactory::Instance();
ASSERT_EQ(factory.GetFilterFunctionNum(), 2);
auto all_functions = factory.ListAllFilterFunctions();
// for (auto& f : all_functions) {
// std::cout << f.toString() << std::endl;
// }
auto func_ptr = factory.GetFilterFunction(
milvus::exec::expression::FilterFunctionRegisterKey{
"empty", {DataType::VARCHAR}});
ASSERT_TRUE(func_ptr != nullptr);
}
TEST_P(TaskTest, CallExprEmpty) {
expr::ColumnInfo col(field_map_["string1"], DataType::VARCHAR);
std::vector<milvus::expr::TypedExprPtr> parameters;
parameters.push_back(std::make_shared<milvus::expr::ColumnExpr>(col));
milvus::exec::expression::FunctionFactory& factory =
milvus::exec::expression::FunctionFactory::Instance();
auto empty_function_ptr = factory.GetFilterFunction(
milvus::exec::expression::FilterFunctionRegisterKey{
"empty", {DataType::VARCHAR}});
auto call_expr = std::make_shared<milvus::expr::CallExpr>(
"empty", parameters, empty_function_ptr);
ASSERT_EQ(call_expr->inputs().size(), 1);
std::vector<milvus::plan::PlanNodePtr> sources;
auto filter_node = std::make_shared<milvus::plan::FilterBitsNode>(
"plannode id 1", call_expr, sources);
auto plan = plan::PlanFragment(filter_node);
auto query_context = std::make_shared<milvus::exec::QueryContext>(
"test1",
segment_.get(),
1000000,
MAX_TIMESTAMP,
std::make_shared<milvus::exec::QueryConfig>(
std::unordered_map<std::string, std::string>{}));
auto start = std::chrono::steady_clock::now();
auto task = Task::Create("task_call_expr_empty", plan, 0, query_context);
int64_t num_rows = 0;
for (;;) {
auto result = task->Next();
if (!result) {
break;
}
num_rows += result->size();
}
auto cost = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count();
std::cout << "cost: " << cost << "us" << std::endl;
EXPECT_EQ(num_rows, num_rows_);
}
TEST_P(TaskTest, UnaryExpr) { TEST_P(TaskTest, UnaryExpr) {
::milvus::proto::plan::GenericValue value; ::milvus::proto::plan::GenericValue value;
value.set_int64_val(-1); value.set_int64_val(-1);
@ -355,4 +416,4 @@ TEST_P(TaskTest, CompileInputs_or_with_and) {
"PhyUnaryRangeFilterExpr"); "PhyUnaryRangeFilterExpr");
} }
} }
} }

View File

@ -15,6 +15,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <regex> #include <regex>
#include <string>
#include <vector> #include <vector>
#include <chrono> #include <chrono>
#include <roaring/roaring.hh> #include <roaring/roaring.hh>
@ -33,6 +34,7 @@
#include "index/IndexFactory.h" #include "index/IndexFactory.h"
#include "exec/expression/Expr.h" #include "exec/expression/Expr.h"
#include "exec/Task.h" #include "exec/Task.h"
#include "exec/expression/function/FunctionFactory.h"
#include "expr/ITypeExpr.h" #include "expr/ITypeExpr.h"
#include "index/BitmapIndex.h" #include "index/BitmapIndex.h"
#include "index/InvertedIndexTantivy.h" #include "index/InvertedIndexTantivy.h"
@ -1739,6 +1741,175 @@ TEST_P(ExprTest, TestTermNullable) {
} }
} }
TEST_P(ExprTest, TestCall) {
milvus::exec::expression::FunctionFactory& factory =
milvus::exec::expression::FunctionFactory::Instance();
factory.Initialize();
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto varchar_fid = schema->AddDebugField("address", DataType::VARCHAR);
schema->set_primary_field_id(varchar_fid);
auto seg = CreateGrowingSegment(schema, empty_index_meta);
int N = 1000;
std::vector<std::string> address_col;
int num_iters = 1;
for (int iter = 0; iter < num_iters; ++iter) {
auto raw_data = DataGen(schema, N, iter);
auto new_address_col = raw_data.get_col<std::string>(varchar_fid);
address_col.insert(
address_col.end(), new_address_col.begin(), new_address_col.end());
seg->PreInsert(N);
seg->Insert(iter * N,
N,
raw_data.row_ids_.data(),
raw_data.timestamps_.data(),
raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentGrowingImpl*>(seg.get());
std::tuple<std::string, std::function<bool(std::string&)>> test_cases[] = {
{R"(vector_anns: <
field_id: 100
predicates: <
call_expr: <
function_name: "empty"
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)",
[](std::string& v) { return v.empty(); }},
{R"(vector_anns: <
field_id: 100
predicates: <
call_expr: <
function_name: "starts_with"
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)",
[](std::string&) { return true; }}};
for (auto& [raw_plan, ref_func] : test_cases) {
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
BitsetType final;
final = ExecuteQueryExpr(
plan->plan_node_->plannodes_->sources()[0]->sources()[0],
seg_promote,
N * num_iters,
MAX_TIMESTAMP);
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
ASSERT_EQ(ans, ref_func(address_col[i]))
<< "@" << i << "!!" << address_col[i];
}
}
std::string incorrect_test_cases[] = {
R"(vector_anns: <
field_id: 100
predicates: <
call_expr: <
function_name: "empty"
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)",
R"(vector_anns: <
field_id: 100
predicates: <
call_expr: <
function_name: "starts_with"
function_parameters: <
column_expr: <
info: <
field_id: 101
data_type: VarChar
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)"};
for (auto& raw_plan : incorrect_test_cases) {
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
EXPECT_ANY_THROW(
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()));
}
}
TEST_P(ExprTest, TestCompare) { TEST_P(ExprTest, TestCompare) {
std::vector<std::tuple<std::string, std::function<bool(int, int64_t)>>> std::vector<std::tuple<std::string, std::function<bool(int, int64_t)>>>
testcases = { testcases = {

View File

@ -0,0 +1,239 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <vector>
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/function/impl/StringFunctions.h"
using namespace milvus;
using namespace milvus::exec::expression::function;
class FunctionTest : public ::testing::Test {
protected:
void
SetUp() override {
}
void
TearDown() override {
}
};
TEST_F(FunctionTest, Empty) {
std::vector<milvus::VectorPtr> arg_vec;
auto col1 =
std::make_shared<milvus::ColumnVector>(milvus::DataType::STRING, 15);
auto* col1_data = col1->RawAsValues<std::string>();
for (int i = 0; i <= 10; ++i) {
col1_data[i] = std::to_string(i);
}
for (int i = 11; i < 15; ++i) {
col1_data[i] = "";
}
arg_vec.push_back(col1);
milvus::RowVector args(std::move(arg_vec));
VectorPtr result;
EmptyVarchar(args, result);
auto result_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(result);
ASSERT_NE(result_vec, nullptr);
TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size());
for (int i = 0; i < 15; ++i) {
EXPECT_TRUE(result_vec->ValidAt(i)) << "i: " << i;
EXPECT_EQ(bitmap[i], i >= 11) << "i: " << i;
}
}
TEST_F(FunctionTest, EmptyNull) {
std::vector<milvus::VectorPtr> arg_vec;
auto col1 = std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15);
arg_vec.push_back(col1);
milvus::RowVector args(std::move(arg_vec));
VectorPtr result;
EmptyVarchar(args, result);
auto result_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(result);
ASSERT_NE(result_vec, nullptr);
TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size());
for (int i = 0; i < 15; ++i) {
EXPECT_FALSE(result_vec->ValidAt(i)) << "i: " << i;
}
}
TEST_F(FunctionTest, EmptyConstant) {
std::vector<milvus::VectorPtr> arg_vec;
auto col1 = std::make_shared<milvus::ConstantVector<std::string>>(
milvus::DataType::STRING, 15, "xx");
arg_vec.push_back(col1);
milvus::RowVector args(std::move(arg_vec));
VectorPtr result;
EmptyVarchar(args, result);
auto result_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(result);
ASSERT_NE(result_vec, nullptr);
TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size());
for (int i = 0; i < 15; ++i) {
EXPECT_TRUE(result_vec->ValidAt(i)) << "i: " << i;
EXPECT_FALSE(bitmap[i]) << "i: " << i;
}
}
TEST_F(FunctionTest, EmptyIncorrectArgs) {
VectorPtr result;
std::vector<milvus::VectorPtr> arg_vec;
// empty args
milvus::RowVector empty_args(arg_vec);
EXPECT_ANY_THROW(EmptyVarchar(empty_args, result));
// incorrect type, expected string or varchar
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::INT32, 15, 15));
milvus::RowVector int_args(arg_vec);
EXPECT_ANY_THROW(EmptyVarchar(int_args, result));
arg_vec.clear();
// incorrect size, expected 1
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
milvus::RowVector multi_args(arg_vec);
EXPECT_ANY_THROW(EmptyVarchar(multi_args, result));
}
static constexpr int64_t STARTS_WITH_ROW_COUNT = 8;
static void
InitStrsForStartWith(std::shared_ptr<milvus::ColumnVector> col1) {
auto* col1_data = col1->RawAsValues<std::string>();
col1_data[0] = "123";
col1_data[1] = "";
col1_data[2] = "";
TargetBitmapView valid_bitmap_col1(col1->GetValidRawData(), col1->size());
valid_bitmap_col1[2] = false;
col1_data[3] = "aaabbbaaa";
col1_data[4] = "aaabbbaaa";
col1_data[5] = "xx";
col1_data[6] = "xx";
col1_data[7] = "1";
}
static void
StartWithCheck(VectorPtr result, bool valid[], bool expected[]) {
auto result_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(result);
ASSERT_NE(result_vec, nullptr);
TargetBitmapView bitmap(result_vec->GetRawData(), result_vec->size());
for (int i = 0; i < STARTS_WITH_ROW_COUNT; ++i) {
EXPECT_EQ(result_vec->ValidAt(i), valid[i]) << "i: " << i;
EXPECT_EQ(bitmap[i], expected[i]) << "i: " << i;
}
}
TEST_F(FunctionTest, StartsWithColumnVector) {
std::vector<milvus::VectorPtr> arg_vec;
auto col1 = std::make_shared<milvus::ColumnVector>(milvus::DataType::STRING,
STARTS_WITH_ROW_COUNT);
InitStrsForStartWith(col1);
arg_vec.push_back(col1);
auto col2 = std::make_shared<milvus::ColumnVector>(milvus::DataType::STRING,
STARTS_WITH_ROW_COUNT);
auto* col2_data = col2->RawAsValues<std::string>();
col2_data[0] = "12";
col2_data[1] = "1";
col2_data[3] = "aaabbbaaac";
col2_data[4] = "aaabbbaax";
col2_data[5] = "";
col2_data[6] = "";
TargetBitmapView valid_bitmap_col2(col2->GetValidRawData(), col2->size());
valid_bitmap_col2[6] = false;
col2_data[7] = "124";
arg_vec.push_back(col2);
milvus::RowVector args(std::move(arg_vec));
bool valid[STARTS_WITH_ROW_COUNT] = {
true, true, false, true, true, true, false, true};
bool expected[STARTS_WITH_ROW_COUNT] = {
true, false, false, false, false, true, false, false};
VectorPtr result;
StartsWithVarchar(args, result);
StartWithCheck(result, valid, expected);
}
TEST_F(FunctionTest, StartsWithColumnAndConstantVector) {
std::vector<milvus::VectorPtr> arg_vec;
auto col1 = std::make_shared<milvus::ColumnVector>(milvus::DataType::STRING,
STARTS_WITH_ROW_COUNT);
InitStrsForStartWith(col1);
arg_vec.push_back(col1);
const std::string constant_str = "1";
auto col2 = std::make_shared<milvus::ConstantVector<std::string>>(
milvus::DataType::STRING, STARTS_WITH_ROW_COUNT, constant_str);
arg_vec.push_back(col2);
milvus::RowVector args(std::move(arg_vec));
bool valid[STARTS_WITH_ROW_COUNT] = {
true, true, false, true, true, true, true, true};
bool expected[STARTS_WITH_ROW_COUNT] = {
true, false, false, false, false, false, false, true};
VectorPtr result;
StartsWithVarchar(args, result);
StartWithCheck(result, valid, expected);
}
TEST_F(FunctionTest, StartsWithIncorrectArgs) {
VectorPtr result;
std::vector<milvus::VectorPtr> arg_vec;
// empty args
milvus::RowVector empty_args(arg_vec);
EXPECT_ANY_THROW(StartsWithVarchar(empty_args, result));
// incorrect type, expected string or varchar
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::INT32, 15, 15));
milvus::RowVector string_int_args(arg_vec);
EXPECT_ANY_THROW(StartsWithVarchar(string_int_args, result));
arg_vec.clear();
// incorrect size, expected 2
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
milvus::RowVector single_args(arg_vec);
EXPECT_ANY_THROW(StartsWithVarchar(single_args, result));
arg_vec.clear();
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
arg_vec.push_back(std::make_shared<milvus::ColumnVector>(
milvus::DataType::STRING, 15, 15));
milvus::RowVector three_args(arg_vec);
EXPECT_ANY_THROW(StartsWithVarchar(three_args, result));
}

View File

@ -24,6 +24,7 @@ expr:
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll | (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny | (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength | ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
| Identifier '(' ( expr (',' expr )* ','? )? ')' # Call
| expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range | expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range
| expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange | expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange
| expr op = (LT | LE | GT | GE) expr # Relational | expr op = (LT | LE | GT | GE) expr # Relational

View File

@ -14,17 +14,27 @@ func TestCheckIdentical(t *testing.T) {
helper, err := typeutil.CreateSchemaHelper(schema) helper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err) assert.NoError(t, err)
exprStr1 := `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))` exprStr1Arr := []string{
exprStr2 := `Int32Field in [1, 2, 3]` `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))`,
`f1()`,
}
exprStr2Arr := []string{
`Int32Field in [1, 2, 3]`,
`f2(Int32Field, Int64Field)`,
}
for i := range exprStr1Arr {
exprStr1 := exprStr1Arr[i]
exprStr2 := exprStr2Arr[i]
expr1, err := ParseExpr(helper, exprStr1) expr1, err := ParseExpr(helper, exprStr1)
assert.NoError(t, err) assert.NoError(t, err)
expr2, err := ParseExpr(helper, exprStr2) expr2, err := ParseExpr(helper, exprStr2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, CheckPredicatesIdentical(expr1, expr1)) assert.True(t, CheckPredicatesIdentical(expr1, expr1))
assert.True(t, CheckPredicatesIdentical(expr2, expr2)) assert.True(t, CheckPredicatesIdentical(expr2, expr2))
assert.False(t, CheckPredicatesIdentical(expr1, expr2)) assert.False(t, CheckPredicatesIdentical(expr1, expr2))
}
} }
func TestCheckQueryInfoIdentical(t *testing.T) { func TestCheckQueryInfoIdentical(t *testing.T) {

View File

@ -101,4 +101,4 @@ expr
atn: atn:
[4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, 41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, 64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, 2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, 63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, 0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, 1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, 63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, 5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, 0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, 118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, 0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, 0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, 86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, 7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, 93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, 10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, 0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, 0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, 106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, 110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, 5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, 12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, 0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, 83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, 0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, 117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, 121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, 0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119] [4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134, 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0, 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5, 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65, 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73, 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0, 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5, 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79, 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0, 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0, 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43, 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83, 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5, 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89, 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3, 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0, 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3, 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4, 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0, 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108, 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134, 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3, 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0, 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0, 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5, 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127, 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131, 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1, 0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0, 133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109, 1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0, 0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0, 133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135, 136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68, 72, 74, 79, 95, 133, 135]

View File

@ -59,6 +59,10 @@ func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} { func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }

View File

@ -50,65 +50,73 @@ func planParserInit() {
} }
staticData.PredictionContextCache = antlr.NewPredictionContextCache() staticData.PredictionContextCache = antlr.NewPredictionContextCache()
staticData.serializedATN = []int32{ staticData.serializedATN = []int32{
4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12,
0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0,
0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134,
0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16,
36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37,
15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6,
154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3,
41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0,
64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1,
0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14,
5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0,
19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1,
0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23,
25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0,
0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30,
32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0,
0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37,
39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0,
0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7,
3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47,
49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0,
0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55,
64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0,
0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5,
2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65,
63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0,
0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73,
1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0,
63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5,
5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79,
0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0,
118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0,
0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43,
79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0,
0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83,
86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5,
7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89,
93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3,
10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0,
0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3,
0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4,
106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0,
110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108,
5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134,
12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3,
0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0,
83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0,
0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5,
117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127,
121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131,
0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119, 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1,
0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0,
133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109,
1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0,
0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0,
133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135,
136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68,
72, 74, 79, 95, 133, 135,
} }
deserializer := antlr.NewATNDeserializer(nil) deserializer := antlr.NewATNDeserializer(nil)
staticData.atn = deserializer.Deserialize(staticData.serializedATN) staticData.atn = deserializer.Deserialize(staticData.serializedATN)
@ -981,6 +989,79 @@ func (s *ShiftContext) Accept(visitor antlr.ParseTreeVisitor) interface{} {
} }
} }
type CallContext struct {
ExprContext
}
func NewCallContext(parser antlr.Parser, ctx antlr.ParserRuleContext) *CallContext {
var p = new(CallContext)
InitEmptyExprContext(&p.ExprContext)
p.parser = parser
p.CopyAll(ctx.(*ExprContext))
return p
}
func (s *CallContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *CallContext) Identifier() antlr.TerminalNode {
return s.GetToken(PlanParserIdentifier, 0)
}
func (s *CallContext) AllExpr() []IExprContext {
children := s.GetChildren()
len := 0
for _, ctx := range children {
if _, ok := ctx.(IExprContext); ok {
len++
}
}
tst := make([]IExprContext, len)
i := 0
for _, ctx := range children {
if t, ok := ctx.(IExprContext); ok {
tst[i] = t.(IExprContext)
i++
}
}
return tst
}
func (s *CallContext) Expr(i int) IExprContext {
var t antlr.RuleContext
j := 0
for _, ctx := range s.GetChildren() {
if _, ok := ctx.(IExprContext); ok {
if j == i {
t = ctx.(antlr.RuleContext)
break
}
j++
}
}
if t == nil {
return nil
}
return t.(IExprContext)
}
func (s *CallContext) Accept(visitor antlr.ParseTreeVisitor) interface{} {
switch t := visitor.(type) {
case PlanVisitor:
return t.VisitCall(s)
default:
return t.VisitChildren(s)
}
}
type ReverseRangeContext struct { type ReverseRangeContext struct {
ExprContext ExprContext
op1 antlr.Token op1 antlr.Token
@ -2231,14 +2312,14 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
var _alt int var _alt int
p.EnterOuterAlt(localctx, 1) p.EnterOuterAlt(localctx, 1)
p.SetState(63) p.SetState(79)
p.GetErrorHandler().Sync(p) p.GetErrorHandler().Sync(p)
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
switch p.GetTokenStream().LA(1) { switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) {
case PlanParserIntegerConstant: case 1:
localctx = NewIntegerContext(p, localctx) localctx = NewIntegerContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2252,7 +2333,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserFloatingConstant: case 2:
localctx = NewFloatingContext(p, localctx) localctx = NewFloatingContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2265,7 +2346,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserBooleanConstant: case 3:
localctx = NewBooleanContext(p, localctx) localctx = NewBooleanContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2278,7 +2359,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserStringLiteral: case 4:
localctx = NewStringContext(p, localctx) localctx = NewStringContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2291,7 +2372,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserIdentifier: case 5:
localctx = NewIdentifierContext(p, localctx) localctx = NewIdentifierContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2304,7 +2385,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserJSONIdentifier: case 6:
localctx = NewJSONIdentifierContext(p, localctx) localctx = NewJSONIdentifierContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2317,7 +2398,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserT__0: case 7:
localctx = NewParensContext(p, localctx) localctx = NewParensContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2342,7 +2423,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserT__2: case 8:
localctx = NewArrayContext(p, localctx) localctx = NewArrayContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2420,7 +2501,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserEmptyArray: case 9:
localctx = NewEmptyArrayContext(p, localctx) localctx = NewEmptyArrayContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2433,7 +2514,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserTEXTMATCH: case 10:
localctx = NewTextMatchContext(p, localctx) localctx = NewTextMatchContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2486,7 +2567,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserADD, PlanParserSUB, PlanParserBNOT, PlanParserNOT: case 11:
localctx = NewUnaryContext(p, localctx) localctx = NewUnaryContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2510,10 +2591,10 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
{ {
p.SetState(35) p.SetState(35)
p.expr(19) p.expr(20)
} }
case PlanParserJSONContains, PlanParserArrayContains: case 12:
localctx = NewJSONContainsContext(p, localctx) localctx = NewJSONContainsContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2561,7 +2642,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserJSONContainsAll, PlanParserArrayContainsAll: case 13:
localctx = NewJSONContainsAllContext(p, localctx) localctx = NewJSONContainsAllContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2609,7 +2690,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserJSONContainsAny, PlanParserArrayContainsAny: case 14:
localctx = NewJSONContainsAnyContext(p, localctx) localctx = NewJSONContainsAnyContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2657,7 +2738,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserArrayLength: case 15:
localctx = NewArrayLengthContext(p, localctx) localctx = NewArrayLengthContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
@ -2697,13 +2778,13 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
case PlanParserEXISTS: case 16:
localctx = NewExistsContext(p, localctx) localctx = NewCallContext(p, localctx)
p.SetParserRuleContext(localctx) p.SetParserRuleContext(localctx)
_prevctx = localctx _prevctx = localctx
{ {
p.SetState(61) p.SetState(61)
p.Match(PlanParserEXISTS) p.Match(PlanParserIdentifier)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
goto errorExit goto errorExit
@ -2711,20 +2792,115 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
{ {
p.SetState(62) p.SetState(62)
p.Match(PlanParserT__0)
if p.HasError() {
// Recognition error - abort rule
goto errorExit
}
}
p.SetState(74)
p.GetErrorHandler().Sync(p)
if p.HasError() {
goto errorExit
}
_la = p.GetTokenStream().LA(1)
if (int64(_la) & ^0x3f) == 0 && ((int64(1)<<_la)&35183030034442) != 0 {
{
p.SetState(63)
p.expr(0)
}
p.SetState(68)
p.GetErrorHandler().Sync(p)
if p.HasError() {
goto errorExit
}
_alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext())
if p.HasError() {
goto errorExit
}
for _alt != 2 && _alt != antlr.ATNInvalidAltNumber {
if _alt == 1 {
{
p.SetState(64)
p.Match(PlanParserT__3)
if p.HasError() {
// Recognition error - abort rule
goto errorExit
}
}
{
p.SetState(65)
p.expr(0)
}
}
p.SetState(70)
p.GetErrorHandler().Sync(p)
if p.HasError() {
goto errorExit
}
_alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext())
if p.HasError() {
goto errorExit
}
}
p.SetState(72)
p.GetErrorHandler().Sync(p)
if p.HasError() {
goto errorExit
}
_la = p.GetTokenStream().LA(1)
if _la == PlanParserT__3 {
{
p.SetState(71)
p.Match(PlanParserT__3)
if p.HasError() {
// Recognition error - abort rule
goto errorExit
}
}
}
}
{
p.SetState(76)
p.Match(PlanParserT__1)
if p.HasError() {
// Recognition error - abort rule
goto errorExit
}
}
case 17:
localctx = NewExistsContext(p, localctx)
p.SetParserRuleContext(localctx)
_prevctx = localctx
{
p.SetState(77)
p.Match(PlanParserEXISTS)
if p.HasError() {
// Recognition error - abort rule
goto errorExit
}
}
{
p.SetState(78)
p.expr(1) p.expr(1)
} }
default: case antlr.ATNInvalidAltNumber:
p.SetError(antlr.NewNoViableAltException(p, nil, nil, nil, nil, nil))
goto errorExit goto errorExit
} }
p.GetParserRuleContext().SetStop(p.GetTokenStream().LT(-1)) p.GetParserRuleContext().SetStop(p.GetTokenStream().LT(-1))
p.SetState(119) p.SetState(135)
p.GetErrorHandler().Sync(p) p.GetErrorHandler().Sync(p)
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
_alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext())
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
@ -2734,24 +2910,24 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
p.TriggerExitRuleEvent() p.TriggerExitRuleEvent()
} }
_prevctx = localctx _prevctx = localctx
p.SetState(117) p.SetState(133)
p.GetErrorHandler().Sync(p) p.GetErrorHandler().Sync(p)
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 4, p.GetParserRuleContext()) { switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 7, p.GetParserRuleContext()) {
case 1: case 1:
localctx = NewPowerContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewPowerContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(65) p.SetState(81)
if !(p.Precpred(p.GetParserRuleContext(), 20)) { if !(p.Precpred(p.GetParserRuleContext(), 21)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 20)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 21)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(66) p.SetState(82)
p.Match(PlanParserPOW) p.Match(PlanParserPOW)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -2759,21 +2935,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(67) p.SetState(83)
p.expr(21) p.expr(22)
} }
case 2: case 2:
localctx = NewMulDivModContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewMulDivModContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(68) p.SetState(84)
if !(p.Precpred(p.GetParserRuleContext(), 18)) { if !(p.Precpred(p.GetParserRuleContext(), 19)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 19)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(69) p.SetState(85)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2791,21 +2967,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(70) p.SetState(86)
p.expr(19) p.expr(20)
} }
case 3: case 3:
localctx = NewAddSubContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewAddSubContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(71) p.SetState(87)
if !(p.Precpred(p.GetParserRuleContext(), 17)) { if !(p.Precpred(p.GetParserRuleContext(), 18)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(72) p.SetState(88)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2823,21 +2999,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(73) p.SetState(89)
p.expr(18) p.expr(19)
} }
case 4: case 4:
localctx = NewShiftContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewShiftContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(74) p.SetState(90)
if !(p.Precpred(p.GetParserRuleContext(), 16)) { if !(p.Precpred(p.GetParserRuleContext(), 17)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(75) p.SetState(91)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2855,20 +3031,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(76) p.SetState(92)
p.expr(17) p.expr(18)
} }
case 5: case 5:
localctx = NewTermContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewTermContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(77) p.SetState(93)
if !(p.Precpred(p.GetParserRuleContext(), 15)) { if !(p.Precpred(p.GetParserRuleContext(), 16)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 15)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", ""))
goto errorExit goto errorExit
} }
p.SetState(79) p.SetState(95)
p.GetErrorHandler().Sync(p) p.GetErrorHandler().Sync(p)
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
@ -2877,7 +3053,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
if _la == PlanParserNOT { if _la == PlanParserNOT {
{ {
p.SetState(78) p.SetState(94)
var _m = p.Match(PlanParserNOT) var _m = p.Match(PlanParserNOT)
@ -2890,7 +3066,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
{ {
p.SetState(81) p.SetState(97)
p.Match(PlanParserIN) p.Match(PlanParserIN)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -2898,21 +3074,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(82) p.SetState(98)
p.expr(16) p.expr(17)
} }
case 6: case 6:
localctx = NewRangeContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewRangeContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(83) p.SetState(99)
if !(p.Precpred(p.GetParserRuleContext(), 10)) { if !(p.Precpred(p.GetParserRuleContext(), 10)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 10)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 10)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(84) p.SetState(100)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2930,7 +3106,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(85) p.SetState(101)
_la = p.GetTokenStream().LA(1) _la = p.GetTokenStream().LA(1)
if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) {
@ -2941,7 +3117,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(86) p.SetState(102)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2959,21 +3135,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(87) p.SetState(103)
p.expr(11) p.expr(11)
} }
case 7: case 7:
localctx = NewReverseRangeContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewReverseRangeContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(88) p.SetState(104)
if !(p.Precpred(p.GetParserRuleContext(), 9)) { if !(p.Precpred(p.GetParserRuleContext(), 9)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 9)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 9)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(89) p.SetState(105)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -2991,7 +3167,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(90) p.SetState(106)
_la = p.GetTokenStream().LA(1) _la = p.GetTokenStream().LA(1)
if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) {
@ -3002,7 +3178,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(91) p.SetState(107)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -3020,21 +3196,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(92) p.SetState(108)
p.expr(10) p.expr(10)
} }
case 8: case 8:
localctx = NewRelationalContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewRelationalContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(93) p.SetState(109)
if !(p.Precpred(p.GetParserRuleContext(), 8)) { if !(p.Precpred(p.GetParserRuleContext(), 8)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 8)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 8)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(94) p.SetState(110)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -3052,21 +3228,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(95) p.SetState(111)
p.expr(9) p.expr(9)
} }
case 9: case 9:
localctx = NewEqualityContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewEqualityContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(96) p.SetState(112)
if !(p.Precpred(p.GetParserRuleContext(), 7)) { if !(p.Precpred(p.GetParserRuleContext(), 7)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 7)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 7)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(97) p.SetState(113)
var _lt = p.GetTokenStream().LT(1) var _lt = p.GetTokenStream().LT(1)
@ -3084,21 +3260,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(98) p.SetState(114)
p.expr(8) p.expr(8)
} }
case 10: case 10:
localctx = NewBitAndContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewBitAndContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(99) p.SetState(115)
if !(p.Precpred(p.GetParserRuleContext(), 6)) { if !(p.Precpred(p.GetParserRuleContext(), 6)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 6)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 6)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(100) p.SetState(116)
p.Match(PlanParserBAND) p.Match(PlanParserBAND)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3106,21 +3282,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(101) p.SetState(117)
p.expr(7) p.expr(7)
} }
case 11: case 11:
localctx = NewBitXorContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewBitXorContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(102) p.SetState(118)
if !(p.Precpred(p.GetParserRuleContext(), 5)) { if !(p.Precpred(p.GetParserRuleContext(), 5)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 5)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 5)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(103) p.SetState(119)
p.Match(PlanParserBXOR) p.Match(PlanParserBXOR)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3128,21 +3304,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(104) p.SetState(120)
p.expr(6) p.expr(6)
} }
case 12: case 12:
localctx = NewBitOrContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewBitOrContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(105) p.SetState(121)
if !(p.Precpred(p.GetParserRuleContext(), 4)) { if !(p.Precpred(p.GetParserRuleContext(), 4)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 4)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 4)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(106) p.SetState(122)
p.Match(PlanParserBOR) p.Match(PlanParserBOR)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3150,21 +3326,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(107) p.SetState(123)
p.expr(5) p.expr(5)
} }
case 13: case 13:
localctx = NewLogicalAndContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewLogicalAndContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(108) p.SetState(124)
if !(p.Precpred(p.GetParserRuleContext(), 3)) { if !(p.Precpred(p.GetParserRuleContext(), 3)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 3)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 3)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(109) p.SetState(125)
p.Match(PlanParserAND) p.Match(PlanParserAND)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3172,21 +3348,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(110) p.SetState(126)
p.expr(4) p.expr(4)
} }
case 14: case 14:
localctx = NewLogicalOrContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewLogicalOrContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(111) p.SetState(127)
if !(p.Precpred(p.GetParserRuleContext(), 2)) { if !(p.Precpred(p.GetParserRuleContext(), 2)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(112) p.SetState(128)
p.Match(PlanParserOR) p.Match(PlanParserOR)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3194,21 +3370,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(113) p.SetState(129)
p.expr(3) p.expr(3)
} }
case 15: case 15:
localctx = NewLikeContext(p, NewExprContext(p, _parentctx, _parentState)) localctx = NewLikeContext(p, NewExprContext(p, _parentctx, _parentState))
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr)
p.SetState(114) p.SetState(130)
if !(p.Precpred(p.GetParserRuleContext(), 22)) { if !(p.Precpred(p.GetParserRuleContext(), 23)) {
p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 22)", "")) p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 23)", ""))
goto errorExit goto errorExit
} }
{ {
p.SetState(115) p.SetState(131)
p.Match(PlanParserLIKE) p.Match(PlanParserLIKE)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3216,7 +3392,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
{ {
p.SetState(116) p.SetState(132)
p.Match(PlanParserStringLiteral) p.Match(PlanParserStringLiteral)
if p.HasError() { if p.HasError() {
// Recognition error - abort rule // Recognition error - abort rule
@ -3229,12 +3405,12 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) {
} }
} }
p.SetState(121) p.SetState(137)
p.GetErrorHandler().Sync(p) p.GetErrorHandler().Sync(p)
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
_alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext())
if p.HasError() { if p.HasError() {
goto errorExit goto errorExit
} }
@ -3270,19 +3446,19 @@ func (p *PlanParser) Sempred(localctx antlr.RuleContext, ruleIndex, predIndex in
func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) bool { func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) bool {
switch predIndex { switch predIndex {
case 0: case 0:
return p.Precpred(p.GetParserRuleContext(), 20) return p.Precpred(p.GetParserRuleContext(), 21)
case 1: case 1:
return p.Precpred(p.GetParserRuleContext(), 18) return p.Precpred(p.GetParserRuleContext(), 19)
case 2: case 2:
return p.Precpred(p.GetParserRuleContext(), 17) return p.Precpred(p.GetParserRuleContext(), 18)
case 3: case 3:
return p.Precpred(p.GetParserRuleContext(), 16) return p.Precpred(p.GetParserRuleContext(), 17)
case 4: case 4:
return p.Precpred(p.GetParserRuleContext(), 15) return p.Precpred(p.GetParserRuleContext(), 16)
case 5: case 5:
return p.Precpred(p.GetParserRuleContext(), 10) return p.Precpred(p.GetParserRuleContext(), 10)
@ -3312,7 +3488,7 @@ func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) boo
return p.Precpred(p.GetParserRuleContext(), 2) return p.Precpred(p.GetParserRuleContext(), 2)
case 14: case 14:
return p.Precpred(p.GetParserRuleContext(), 22) return p.Precpred(p.GetParserRuleContext(), 23)
default: default:
panic("No predicate with index: " + fmt.Sprint(predIndex)) panic("No predicate with index: " + fmt.Sprint(predIndex))

View File

@ -46,6 +46,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#Shift. // Visit a parse tree produced by PlanParser#Shift.
VisitShift(ctx *ShiftContext) interface{} VisitShift(ctx *ShiftContext) interface{}
// Visit a parse tree produced by PlanParser#Call.
VisitCall(ctx *CallContext) interface{}
// Visit a parse tree produced by PlanParser#ReverseRange. // Visit a parse tree produced by PlanParser#ReverseRange.
VisitReverseRange(ctx *ReverseRangeContext) interface{} VisitReverseRange(ctx *ReverseRangeContext) interface{}

View File

@ -594,6 +594,28 @@ func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode)
return v.getColumnInfoFromJSONIdentifier(child.GetText()) return v.getColumnInfoFromJSONIdentifier(child.GetText())
} }
// VisitCall parses the expr to call plan.
func (v *ParserVisitor) VisitCall(ctx *parser.CallContext) interface{} {
functionName := strings.ToLower(ctx.Identifier().GetText())
numParams := len(ctx.AllExpr())
funcParameters := make([]*planpb.Expr, 0, numParams)
for _, param := range ctx.AllExpr() {
paramExpr := getExpr(param.Accept(v))
funcParameters = append(funcParameters, paramExpr.expr)
}
return &ExprWithType{
expr: &planpb.Expr{
Expr: &planpb.Expr_CallExpr{
CallExpr: &planpb.CallExpr{
FunctionName: functionName,
FunctionParameters: funcParameters,
},
},
},
dataType: schemapb.DataType_Bool,
}
}
// VisitRange translates expr to range plan. // VisitRange translates expr to range plan.
func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} {
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier()) columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())

View File

@ -58,13 +58,11 @@ func newTestSchemaHelper(t *testing.T) *typeutil.SchemaHelper {
} }
func assertValidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) { func assertValidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) {
_, err := ParseExpr(helper, exprStr) expr, err := ParseExpr(helper, exprStr)
assert.NoError(t, err, exprStr) assert.NoError(t, err, exprStr)
// expr, err := ParseExpr(helper, exprStr)
// assert.NoError(t, err, exprStr)
// fmt.Printf("expr: %s\n", exprStr) // fmt.Printf("expr: %s\n", exprStr)
// ShowExpr(expr) assert.NotNil(t, expr, exprStr)
ShowExpr(expr)
} }
func assertInvalidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) { func assertInvalidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) {
@ -106,6 +104,43 @@ func TestExpr_Term(t *testing.T) {
} }
} }
func TestExpr_Call(t *testing.T) {
schema := newTestSchema()
helper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
testcases := []struct {
CallExpr string
FunctionName string
ParameterNum int
}{
{`hello123()`, "hello123", 0},
{`lt(Int32Field)`, "lt", 1},
// test parens
{`lt((((Int32Field))))`, "lt", 1},
{`empty(VarCharField,)`, "empty", 1},
{`f2(Int64Field)`, "f2", 1},
{`f2(Int64Field, 4)`, "f2", 2},
{`f3(JSON_FIELD["A"], Int32Field)`, "f3", 2},
{`f5(3+3, Int32Field)`, "f5", 2},
}
for _, testcase := range testcases {
expr, err := ParseExpr(helper, testcase.CallExpr)
assert.NoError(t, err, testcase)
assert.Equal(t, testcase.FunctionName, expr.GetCallExpr().FunctionName, testcase)
assert.Equal(t, testcase.ParameterNum, len(expr.GetCallExpr().FunctionParameters), testcase)
ShowExpr(expr)
}
expr, err := ParseExpr(helper, "xxx(1+1, !true, f(10+10))")
assert.NoError(t, err)
assert.Equal(t, "xxx", expr.GetCallExpr().FunctionName)
assert.Equal(t, 3, len(expr.GetCallExpr().FunctionParameters))
assert.Equal(t, int64(2), expr.GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val())
assert.Equal(t, false, expr.GetCallExpr().GetFunctionParameters()[1].GetValueExpr().GetValue().GetBoolVal())
assert.Equal(t, int64(20), expr.GetCallExpr().GetFunctionParameters()[2].GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val())
}
func TestExpr_Compare(t *testing.T) { func TestExpr_Compare(t *testing.T) {
schema := newTestSchema() schema := newTestSchema()
helper, err := typeutil.CreateSchemaHelper(schema) helper, err := typeutil.CreateSchemaHelper(schema)
@ -247,7 +282,7 @@ func TestExpr_BinaryArith(t *testing.T) {
exprStrs := []string{ exprStrs := []string{
`Int64Field % 10 == 9`, `Int64Field % 10 == 9`,
`Int64Field % 10 != 9`, `Int64Field % 10 != 9`,
`Int64Field + 1.1 == 2.1`, `FloatField + 1.1 == 2.1`,
`A % 10 != 2`, `A % 10 != 2`,
`Int8Field + 1 < 2`, `Int8Field + 1 < 2`,
`Int16Field - 3 <= 4`, `Int16Field - 3 <= 4`,
@ -265,6 +300,13 @@ func TestExpr_BinaryArith(t *testing.T) {
assertValidExpr(t, helper, exprStr) assertValidExpr(t, helper, exprStr)
} }
invalidExprs := []string{
`Int64Field + 1.1 == 2.1`,
}
for _, exprStr := range invalidExprs {
assertInvalidExpr(t, helper, exprStr)
}
// TODO: enable these after execution backend is ready. // TODO: enable these after execution backend is ready.
unsupported := []string{ unsupported := []string{
`ArrayField + 15 == 16`, `ArrayField + 15 == 16`,
@ -286,6 +328,7 @@ func TestExpr_Value(t *testing.T) {
`true`, `true`,
`false`, `false`,
`"str"`, `"str"`,
`3 > 2`,
} }
for _, exprStr := range exprStrs { for _, exprStr := range exprStrs {
expr := handleExpr(helper, exprStr) expr := handleExpr(helper, exprStr)
@ -935,6 +978,8 @@ func Test_JSONContains(t *testing.T) {
`json_contains(JSONField["x"], 5)`, `json_contains(JSONField["x"], 5)`,
`not json_contains(JSONField["x"], 5)`, `not json_contains(JSONField["x"], 5)`,
`JSON_CONTAINS(JSONField["x"], 5)`, `JSON_CONTAINS(JSONField["x"], 5)`,
`json_Contains(JSONField, 5)`,
`JSON_contains(JSONField, 5)`,
`json_contains(A, [1,2,3])`, `json_contains(A, [1,2,3])`,
`array_contains(A, [1,2,3])`, `array_contains(A, [1,2,3])`,
`array_contains(ArrayField, [1,2,3])`, `array_contains(ArrayField, [1,2,3])`,
@ -970,8 +1015,6 @@ func Test_InvalidJSONContains(t *testing.T) {
`json_contains(A, StringField > 5)`, `json_contains(A, StringField > 5)`,
`json_contains(A)`, `json_contains(A)`,
`json_contains(A, 5, C)`, `json_contains(A, 5, C)`,
`json_Contains(JSONField, 5)`,
`JSON_contains(JSONField, 5)`,
} }
for _, expr = range exprs { for _, expr = range exprs {
_, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{

View File

@ -46,6 +46,8 @@ func (v *ShowExprVisitor) VisitExpr(expr *planpb.Expr) interface{} {
js["expr"] = v.VisitUnaryExpr(realExpr.UnaryExpr) js["expr"] = v.VisitUnaryExpr(realExpr.UnaryExpr)
case *planpb.Expr_BinaryExpr: case *planpb.Expr_BinaryExpr:
js["expr"] = v.VisitBinaryExpr(realExpr.BinaryExpr) js["expr"] = v.VisitBinaryExpr(realExpr.BinaryExpr)
case *planpb.Expr_CallExpr:
js["expr"] = v.VisitCallExpr(realExpr.CallExpr)
case *planpb.Expr_CompareExpr: case *planpb.Expr_CompareExpr:
js["expr"] = v.VisitCompareExpr(realExpr.CompareExpr) js["expr"] = v.VisitCompareExpr(realExpr.CompareExpr)
case *planpb.Expr_UnaryRangeExpr: case *planpb.Expr_UnaryRangeExpr:
@ -93,6 +95,18 @@ func (v *ShowExprVisitor) VisitBinaryExpr(expr *planpb.BinaryExpr) interface{} {
return js return js
} }
func (v *ShowExprVisitor) VisitCallExpr(expr *planpb.CallExpr) interface{} {
js := make(map[string]interface{})
js["expr_type"] = "call"
js["func_name"] = expr.FunctionName
params := make([]interface{}, 0, len(expr.FunctionParameters))
for _, p := range expr.FunctionParameters {
params = append(params, v.VisitExpr(p))
}
js["func_parameters"] = params
return js
}
func (v *ShowExprVisitor) VisitCompareExpr(expr *planpb.CompareExpr) interface{} { func (v *ShowExprVisitor) VisitCompareExpr(expr *planpb.CompareExpr) interface{} {
js := make(map[string]interface{}) js := make(map[string]interface{})
js["expr_type"] = "compare" js["expr_type"] = "compare"
@ -164,6 +178,6 @@ func NewShowExprVisitor() LogicalExprVisitor {
func ShowExpr(expr *planpb.Expr) { func ShowExpr(expr *planpb.Expr) {
v := NewShowExprVisitor() v := NewShowExprVisitor()
js := v.VisitExpr(expr) js := v.VisitExpr(expr)
b, _ := json.MarshalIndent(js, "", " ") b, _ := json.Marshal(js)
log.Info("[ShowExpr]", zap.String("expr", string(b))) log.Info("[ShowExpr]", zap.String("expr", string(b)))
} }

View File

@ -223,14 +223,14 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb.
return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value) return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value)
} }
func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operand *planpb.GenericValue, value *planpb.GenericValue) *planpb.Expr { func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operand *planpb.GenericValue, value *planpb.GenericValue) (*planpb.Expr, error) {
dataType := columnInfo.GetDataType() dataType := columnInfo.GetDataType()
if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 { if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 {
dataType = columnInfo.GetElementType() dataType = columnInfo.GetElementType()
} }
castedValue, err := castValue(dataType, operand) castedValue, err := castValue(dataType, operand)
if err != nil { if err != nil {
return nil return nil, err
} }
return &planpb.Expr{ return &planpb.Expr{
Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{
@ -242,7 +242,7 @@ func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, column
Value: value, Value: value,
}, },
}, },
} }, nil
} }
func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, value *planpb.GenericValue) (*planpb.Expr, error) { func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, value *planpb.GenericValue) (*planpb.Expr, error) {
@ -282,7 +282,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr,
// a * 2 == 3 // a * 2 == 3
// a / 2 == 3 // a / 2 == 3
// a % 2 == 3 // a % 2 == 3
return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue()), nil return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue())
} else if rightExpr != nil && leftValue != nil { } else if rightExpr != nil && leftValue != nil {
// 2 + a == 3 // 2 + a == 3
// 2 - a == 3 // 2 - a == 3
@ -292,7 +292,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr,
switch arithExpr.GetOp() { switch arithExpr.GetOp() {
case planpb.ArithOpType_Add, planpb.ArithOpType_Mul: case planpb.ArithOpType_Add, planpb.ArithOpType_Mul:
return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue()), nil return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue())
default: default:
return nil, fmt.Errorf("module field is not yet supported") return nil, fmt.Errorf("module field is not yet supported")
} }

View File

@ -105,6 +105,11 @@ message BinaryRangeExpr {
GenericValue upper_value = 5; GenericValue upper_value = 5;
} }
message CallExpr {
string function_name = 1;
repeated Expr function_parameters = 2;
}
message CompareExpr { message CompareExpr {
ColumnInfo left_column_info = 1; ColumnInfo left_column_info = 1;
ColumnInfo right_column_info = 2; ColumnInfo right_column_info = 2;
@ -191,6 +196,7 @@ message Expr {
ExistsExpr exists_expr = 11; ExistsExpr exists_expr = 11;
AlwaysTrueExpr always_true_expr = 12; AlwaysTrueExpr always_true_expr = 12;
JSONContainsExpr json_contains_expr = 13; JSONContainsExpr json_contains_expr = 13;
CallExpr call_expr = 14;
}; };
} }

View File

@ -23,6 +23,7 @@ package querynodev2
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "segcore/segcore_init_c.h" #include "segcore/segcore_init_c.h"
#include "common/init_c.h" #include "common/init_c.h"
#include "exec/expression/function/init_c.h"
*/ */
import "C" import "C"
@ -252,6 +253,7 @@ func (node *QueryNode) InitSegcore() error {
} }
initcore.InitTraceConfig(paramtable.Get()) initcore.InitTraceConfig(paramtable.Get())
C.InitExecExpressionFunctionFactory()
return nil return nil
} }

View File

@ -151,8 +151,8 @@ var InvalidExpressions = []InvalidExprStruct{
{Expr: fmt.Sprintf("json_contains (%s['list'], [2])", DefaultJSONFieldName), ErrNil: true, ErrMsg: ""}, {Expr: fmt.Sprintf("json_contains (%s['list'], [2])", DefaultJSONFieldName), ErrNil: true, ErrMsg: ""},
{Expr: fmt.Sprintf("json_contains_all (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_all operation element must be an array"}, {Expr: fmt.Sprintf("json_contains_all (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_all operation element must be an array"},
{Expr: fmt.Sprintf("JSON_CONTAINS_ANY (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_any operation element must be an array"}, {Expr: fmt.Sprintf("JSON_CONTAINS_ANY (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "contains_any operation element must be an array"},
{Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression: json_contains_aby"}, {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "function json_contains_aby(json, int64_t) not found."},
{Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression: json_contains_aby"}, {Expr: fmt.Sprintf("json_contains_aby (%s['list'], 2)", DefaultJSONFieldName), ErrNil: false, ErrMsg: "function json_contains_aby(json, int64_t) not found."},
{Expr: fmt.Sprintf("%s[-1] > %d", DefaultInt8ArrayField, TestCapacity), ErrNil: false, ErrMsg: "cannot parse expression"}, // array[-1] > {Expr: fmt.Sprintf("%s[-1] > %d", DefaultInt8ArrayField, TestCapacity), ErrNil: false, ErrMsg: "cannot parse expression"}, // array[-1] >
{Expr: fmt.Sprintf("%s[-1] > 1", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression"}, // json[-1] > {Expr: fmt.Sprintf("%s[-1] > 1", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression"}, // json[-1] >
} }

View File

@ -5415,3 +5415,59 @@ class TestQueryTextMatchNegative(TestcaseBase):
check_task=CheckTasks.err_res, check_task=CheckTasks.err_res,
check_items=error, check_items=error,
) )
class TestQueryFunction(TestcaseBase):
@pytest.mark.tags(CaseLabel.L1)
def test_query_function_calls(self):
"""
target: test query data
method: create collection and insert data
query with mix call expr in string field and int field
expected: query successfully
"""
collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
primary_field=ct.default_string_field_name)[0:2]
res = vectors[0].iloc[:, 1:3].to_dict('records')
output_fields = [default_float_field_name, default_string_field_name]
for mixed_call_expr in [
"not empty(varchar) && int64 >= 0",
# function call is case-insensitive
"not EmPty(varchar) && int64 >= 0",
"not EMPTY(varchar) && int64 >= 0",
"starts_with(varchar, varchar) && int64 >= 0",
]:
collection_w.query(
mixed_call_expr,
output_fields=output_fields,
check_task=CheckTasks.check_query_results,
check_items={exp_res: res},
)
@pytest.mark.tags(CaseLabel.L1)
def test_query_invalid(self):
"""
target: test query with invalid call expression
method: query with invalid call expr
expected: raise exception
"""
collection_w, entities = self.init_collection_general(
prefix, insert_data=True, nb=10
)[0:2]
test_cases = [
(
"A_FUNCTION_THAT_DOES_NOT_EXIST()",
"function A_FUNCTION_THAT_DOES_NOT_EXIST() not found",
),
# empty
("empty()", "function empty() not found"),
(f"empty({default_int_field_name})", "function empty(int64_t) not found"),
# starts_with
(f"starts_with({default_int_field_name})", "function starts_with(int64_t) not found"),
(f"starts_with({default_int_field_name}, {default_int_field_name})", "function starts_with(int64_t, int64_t) not found"),
]
for call_expr, err_msg in test_cases:
error = {ct.err_code: 65535, ct.err_msg: err_msg}
collection_w.query(
call_expr, check_task=CheckTasks.err_res, check_items=error
)