Support SIMD of several Expr (#23715) (#23717)

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
This commit is contained in:
zhagnlu 2023-05-12 14:11:20 +08:00 committed by GitHub
parent 7d0c47dd65
commit 113f9a0ebc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 872 additions and 187 deletions

View File

@ -25,6 +25,7 @@ useasan = false
ifeq (${USE_ASAN}, true)
useasan = true
endif
opensimd = OFF
export GIT_BRANCH=master
@ -157,19 +158,19 @@ generated-proto: download-milvus-proto build-3rdparty
build-cpp: generated-proto
@echo "Building Milvus cpp library ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index})
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd})
build-cpp-gpu: generated-proto
@echo "Building Milvus cpp gpu library ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index})
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd})
build-cpp-with-unittest: generated-proto
@echo "Building Milvus cpp library with unittest ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index})
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd})
build-cpp-with-coverage: generated-proto
@echo "Building Milvus cpp library with coverage and unittest ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -a ${useasan} -c -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index})
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -a ${useasan} -c -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -i ${opensimd})
check-proto-product: generated-proto
@(env bash $(PWD)/scripts/check_proto_product.sh)

View File

@ -133,9 +133,24 @@ if ( APPLE )
"-Wno-unused-parameter"
"-Wno-deprecated"
"-DBOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED=1"
#"-fvisibility=hidden"
#"-fvisibility-inlines-hidden"
)
endif ()
# Append SIMD instruction-set flags to CMAKE_CXX_FLAGS when OPEN_SIMD is on.
# CPU_ARCH selects the flag set; the recognized values are "avx", "sse" and
# "arm64". Any other value leaves CMAKE_CXX_FLAGS untouched and emits a warning.
if (OPEN_SIMD)
    message(STATUS "open simd function, CPU_ARCH:${CPU_ARCH}")
    # The expansion is quoted so that an unset or empty CPU_ARCH still yields a
    # well-formed comparison instead of `if ( STREQUAL "avx")`, which is an error.
    if ("${CPU_ARCH}" STREQUAL "avx")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -mavx -mf16c ")
    elseif ("${CPU_ARCH}" STREQUAL "sse")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 ")
    elseif ("${CPU_ARCH}" STREQUAL "arm64")
        # NOTE(review): -mcpu=apple-m1 looks Apple-specific; confirm this is
        # the intended flag for non-Apple arm64 toolchains.
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1+crc ")
    else ()
        message(WARNING "OPEN_SIMD is enabled but CPU_ARCH '${CPU_ARCH}' is not recognized; no SIMD flags were added")
    endif ()
endif ()
# **************************** Coding style check tools ****************************
find_package( ClangTools )
set( BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support" )

View File

@ -125,7 +125,7 @@ template <typename Type>
using FixedVector = boost::container::vector<Type>;
using Config = nlohmann::json;
using TargetBitmap = boost::dynamic_bitset<>;
using TargetBitmap = FixedVector<bool>;
using TargetBitmapPtr = std::unique_ptr<TargetBitmap>;
using BinaryPtr = knowhere::BinaryPtr;
@ -138,6 +138,10 @@ using IndexType = knowhere::IndexType;
// Plus 1 because we can't use greater(>) symbol
constexpr size_t REF_SIZE_THRESHOLD = 16 + 1;
using BitSetBlockType = BitsetType::block_type;
constexpr size_t BITSET_BLOCK_SIZE = sizeof(BitsetType::block_type);
constexpr size_t BITSET_BLOCK_BIT_SIZE = sizeof(BitsetType::block_type) * 8;
template <typename T>
using MayRef = std::conditional_t<!std::is_trivially_copyable_v<T> ||
sizeof(T) >= REF_SIZE_THRESHOLD,

View File

@ -23,7 +23,7 @@
namespace milvus::index {
template <typename T>
const TargetBitmapPtr
const TargetBitmap
ScalarIndex<T>::Query(const DatasetPtr& dataset) {
auto op = dataset->Get<OpType>(OPERATOR_TYPE);
switch (op) {

View File

@ -45,16 +45,16 @@ class ScalarIndex : public IndexBase {
virtual void
Build(size_t n, const T* values) = 0;
virtual const TargetBitmapPtr
virtual const TargetBitmap
In(size_t n, const T* values) = 0;
virtual const TargetBitmapPtr
virtual const TargetBitmap
NotIn(size_t n, const T* values) = 0;
virtual const TargetBitmapPtr
virtual const TargetBitmap
Range(T value, OpType op) = 0;
virtual const TargetBitmapPtr
virtual const TargetBitmap
Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,
@ -63,7 +63,7 @@ class ScalarIndex : public IndexBase {
virtual T
Reverse_Lookup(size_t offset) const = 0;
virtual const TargetBitmapPtr
virtual const TargetBitmap
Query(const DatasetPtr& dataset);
virtual int64_t

View File

@ -101,10 +101,10 @@ ScalarIndexSort<T>::Load(const BinarySet& index_binary, const Config& config) {
}
template <typename T>
inline const TargetBitmapPtr
inline const TargetBitmap
ScalarIndexSort<T>::In(const size_t n, const T* values) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
TargetBitmap bitset(data_.size());
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
@ -116,18 +116,17 @@ ScalarIndexSort<T>::In(const size_t n, const T* values) {
"experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->set(lb->idx_);
bitset[lb->idx_] = true;
}
}
return bitset;
}
template <typename T>
inline const TargetBitmapPtr
inline const TargetBitmap
ScalarIndexSort<T>::NotIn(const size_t n, const T* values) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
TargetBitmap bitset(data_.size(), true);
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(
data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
@ -139,17 +138,17 @@ ScalarIndexSort<T>::NotIn(const size_t n, const T* values) {
"experted value is: "
<< *(values + i) << ", but real value is: " << lb->a_;
}
bitset->reset(lb->idx_);
bitset[lb->idx_] = false;
}
}
return bitset;
}
template <typename T>
inline const TargetBitmapPtr
inline const TargetBitmap
ScalarIndexSort<T>::Range(const T value, const OpType op) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
TargetBitmap bitset(data_.size());
auto lb = data_.begin();
auto ub = data_.end();
switch (op) {
@ -174,19 +173,19 @@ ScalarIndexSort<T>::Range(const T value, const OpType op) {
std::to_string((int)op) + "!");
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);
bitset[lb->idx_] = true;
}
return bitset;
}
template <typename T>
inline const TargetBitmapPtr
inline const TargetBitmap
ScalarIndexSort<T>::Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,
bool ub_inclusive) {
AssertInfo(is_built_, "index has not been built");
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
TargetBitmap bitset(data_.size());
if (lower_bound_value > upper_bound_value ||
(lower_bound_value == upper_bound_value &&
!(lb_inclusive && ub_inclusive))) {
@ -209,7 +208,7 @@ ScalarIndexSort<T>::Range(T lower_bound_value,
data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));
}
for (; lb < ub; ++lb) {
bitset->set(lb->idx_);
bitset[lb->idx_] = true;
}
return bitset;
}

View File

@ -46,16 +46,16 @@ class ScalarIndexSort : public ScalarIndex<T> {
void
Build(size_t n, const T* values) override;
const TargetBitmapPtr
const TargetBitmap
In(size_t n, const T* values) override;
const TargetBitmapPtr
const TargetBitmap
NotIn(size_t n, const T* values) override;
const TargetBitmapPtr
const TargetBitmap
Range(T value, OpType op) override;
const TargetBitmapPtr
const TargetBitmap
Range(T lower_bound_value,
bool lb_inclusive,
T upper_bound_value,

View File

@ -29,7 +29,7 @@ namespace milvus::index {
class StringIndex : public ScalarIndex<std::string> {
public:
const TargetBitmapPtr
const TargetBitmap
Query(const DatasetPtr& dataset) override {
auto op = dataset->Get<OpType>(OPERATOR_TYPE);
if (op == OpType::PrefixMatch) {
@ -39,7 +39,7 @@ class StringIndex : public ScalarIndex<std::string> {
return ScalarIndex<std::string>::Query(dataset);
}
virtual const TargetBitmapPtr
virtual const TargetBitmap
PrefixMatch(const std::string_view prefix) = 0;
};
using StringIndexPtr = std::unique_ptr<StringIndex>;

View File

@ -127,43 +127,42 @@ valid_str_id(size_t str_id) {
return str_id >= 0 && str_id != MARISA_INVALID_KEY_ID;
}
const TargetBitmapPtr
const TargetBitmap
StringIndexMarisa::In(size_t n, const std::string* values) {
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
TargetBitmap bitset(str_ids_.size());
for (size_t i = 0; i < n; i++) {
auto str = values[i];
auto str_id = lookup(str);
if (valid_str_id(str_id)) {
auto offsets = str_ids_to_offsets_[str_id];
for (auto offset : offsets) {
bitset->set(offset);
bitset[offset] = true;
}
}
}
return bitset;
}
const TargetBitmapPtr
const TargetBitmap
StringIndexMarisa::NotIn(size_t n, const std::string* values) {
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
bitset->set();
TargetBitmap bitset(str_ids_.size(), true);
for (size_t i = 0; i < n; i++) {
auto str = values[i];
auto str_id = lookup(str);
if (valid_str_id(str_id)) {
auto offsets = str_ids_to_offsets_[str_id];
for (auto offset : offsets) {
bitset->reset(offset);
bitset[offset] = false;
}
}
}
return bitset;
}
const TargetBitmapPtr
const TargetBitmap
StringIndexMarisa::Range(std::string value, OpType op) {
auto count = Count();
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(count);
TargetBitmap bitset(count);
marisa::Agent agent;
for (size_t offset = 0; offset < count; ++offset) {
agent.set_query(str_ids_[offset]);
@ -189,19 +188,19 @@ StringIndexMarisa::Range(std::string value, OpType op) {
std::to_string((int)op) + "!");
}
if (set) {
bitset->set(offset);
bitset[offset] = true;
}
}
return bitset;
}
const TargetBitmapPtr
const TargetBitmap
StringIndexMarisa::Range(std::string lower_bound_value,
bool lb_inclusive,
std::string upper_bound_value,
bool ub_inclusive) {
auto count = Count();
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(count);
TargetBitmap bitset(count);
if (lower_bound_value.compare(upper_bound_value) > 0 ||
(lower_bound_value.compare(upper_bound_value) == 0 &&
!(lb_inclusive && ub_inclusive))) {
@ -224,20 +223,20 @@ StringIndexMarisa::Range(std::string lower_bound_value,
set &= raw_data.compare(upper_bound_value) < 0;
}
if (set) {
bitset->set(offset);
bitset[offset] = true;
}
}
return bitset;
}
const TargetBitmapPtr
const TargetBitmap
StringIndexMarisa::PrefixMatch(std::string_view prefix) {
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(str_ids_.size());
TargetBitmap bitset(str_ids_.size());
auto matched = prefix_match(prefix);
for (const auto str_id : matched) {
auto offsets = str_ids_to_offsets_[str_id];
for (auto offset : offsets) {
bitset->set(offset);
bitset[offset] = true;
}
}
return bitset;

View File

@ -48,22 +48,22 @@ class StringIndexMarisa : public StringIndex {
void
Build(size_t n, const std::string* values) override;
const TargetBitmapPtr
const TargetBitmap
In(size_t n, const std::string* values) override;
const TargetBitmapPtr
const TargetBitmap
NotIn(size_t n, const std::string* values) override;
const TargetBitmapPtr
const TargetBitmap
Range(std::string value, OpType op) override;
const TargetBitmapPtr
const TargetBitmap
Range(std::string lower_bound_value,
bool lb_inclusive,
std::string upper_bound_value,
bool ub_inclusive) override;
const TargetBitmapPtr
const TargetBitmap
PrefixMatch(const std::string_view prefix) override;
std::string

View File

@ -27,7 +27,7 @@ namespace milvus::index {
// TODO: should inherit from StringIndex?
class StringIndexSort : public ScalarIndexSort<std::string> {
public:
const TargetBitmapPtr
const TargetBitmap
Query(const DatasetPtr& dataset) override {
auto op = dataset->Get<OpType>(OPERATOR_TYPE);
if (op == OpType::PrefixMatch) {
@ -37,10 +37,10 @@ class StringIndexSort : public ScalarIndexSort<std::string> {
return ScalarIndex<std::string>::Query(dataset);
}
const TargetBitmapPtr
const TargetBitmap
PrefixMatch(std::string_view prefix) {
auto data = GetData();
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data.size());
TargetBitmap bitset(data.size());
auto it = std::lower_bound(
data.begin(),
data.end(),
@ -51,7 +51,7 @@ class StringIndexSort : public ScalarIndexSort<std::string> {
if (!milvus::PrefixMatch(it->a_, prefix)) {
break;
}
bitset->set(it->idx_);
bitset[it->idx_] = true;
}
return bitset;
}

View File

@ -224,7 +224,6 @@ ProtoParser::RetrievePlanNodeFromProto(
std::unique_ptr<Plan>
ProtoParser::CreatePlan(const proto::plan::PlanNode& plan_node_proto) {
// std::cout << plan_node_proto.DebugString() << std::endl;
auto plan = std::make_unique<Plan>(schema);
auto plan_node = PlanNodeFromProto(plan_node_proto);

View File

@ -22,6 +22,10 @@
#include "ExprVisitor.h"
namespace milvus::query {
void
AppendOneChunk(BitsetType& result, const FixedVector<bool>& chunk_res);
class ExecExprVisitor : public ExprVisitor {
public:
void
@ -122,6 +126,27 @@ class ExecExprVisitor : public ExprVisitor {
ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func)
-> BitsetType;
template <typename CmpFunc>
BitsetType
ExecCompareExprDispatcherForNonIndexedSegment(CompareExpr& expr,
CmpFunc cmp_func);
// This function only used to compare sealed segment
// which has only one chunk.
template <typename T, typename U, typename CmpFunc>
TargetBitmap
ExecCompareRightType(const T* left_raw_data,
const FieldId& right_field_id,
const int64_t current_chunk_id,
CmpFunc cmp_func);
template <typename T, typename CmpFunc>
BitsetType
ExecCompareLeftType(const FieldId& left_field_id,
const FieldId& right_field_id,
const DataType& right_field_type,
CmpFunc cmp_func);
private:
const segcore::SegmentInternalInterface& segment_;
Timestamp timestamp_;

View File

@ -12,6 +12,7 @@
#include "query/generated/ExecExprVisitor.h"
#include <boost/variant.hpp>
#include <boost/utility/binary.hpp>
#include <deque>
#include <optional>
#include <string>
@ -152,6 +153,10 @@ static auto
Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
BitsetType res;
if (srcs.size() == 1) {
return srcs[0];
}
int64_t total_size = 0;
for (auto& chunk : srcs) {
total_size += chunk.size();
@ -168,6 +173,69 @@ Assemble(const std::deque<BitsetType>& srcs) -> BitsetType {
return res;
}
// Appends every flag of `chunk_res` to the end of `result`, preserving order:
// chunk_res[i] ends up at bit position (previous result.size() + i).
//
// The copy is split into three phases:
//   1. prefix - single bits until `result` ends on a block boundary,
//   2. blocks - whole blocks, each packed from BITSET_BLOCK_BIT_SIZE flags
//               and appended with one call,
//   3. suffix - the remaining flags (fewer than one block), bit by bit.
//
// All counts are kept as size_t so that no size is narrowed to int and no
// signed/unsigned comparison takes place.
void
AppendOneChunk(BitsetType& result, const FixedVector<bool>& chunk_res) {
    // Append a value once instead of BITSET_BLOCK_BIT_SIZE times.
    auto AppendBlock = [&result](const bool* ptr, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            // Pack the flags into bytes first: bit j of vals[k] is the flag
            // at ptr[k * 8 + j]. The author notes that this loop shape is
            // meant to be friendly to CPU SIMD optimization.
            uint8_t vals[BITSET_BLOCK_SIZE] = {0};
            for (size_t j = 0; j < 8; ++j) {
                for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) {
                    vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j;
                }
            }
            // Combine the bytes into one block; byte j occupies bits
            // [8 * j, 8 * j + 8), so bit p of the block is the flag ptr[p].
            BitSetBlockType val = 0;
            for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) {
                val |= BitSetBlockType(vals[j]) << (8 * j);
            }
            result.append(val);
            ptr += BITSET_BLOCK_BIT_SIZE;
        }
    };

    // Append the bits that cannot be grouped into a whole block.
    // Usually n is less than BITSET_BLOCK_BIT_SIZE.
    auto AppendBit = [&result](const bool* ptr, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            result.push_back(ptr[i]);
        }
    };

    const size_t res_len = result.size();
    const size_t chunk_len = chunk_res.size();
    const bool* chunk_ptr = chunk_res.data();

    // Number of leading flags needed to bring `result` to a block boundary,
    // capped by the number of flags actually available.
    const size_t misaligned = res_len % BITSET_BLOCK_BIT_SIZE;
    const size_t n_prefix =
        misaligned == 0
            ? 0
            : std::min(BITSET_BLOCK_BIT_SIZE - misaligned, chunk_len);

    AppendBit(chunk_ptr, n_prefix);

    // The chunk was empty or fully consumed by the prefix.
    if (n_prefix == chunk_len) {
        return;
    }

    const size_t remaining = chunk_len - n_prefix;
    const size_t n_block = remaining / BITSET_BLOCK_BIT_SIZE;
    const size_t n_suffix = remaining % BITSET_BLOCK_BIT_SIZE;

    AppendBlock(chunk_ptr + n_prefix, n_block);

    AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE,
              n_suffix);
}
// Concatenates the per-chunk boolean results, in order, into one bitset by
// feeding each chunk to AppendOneChunk. An empty input gives an empty bitset.
BitsetType
AssembleChunk(const std::vector<FixedVector<bool>>& results) {
    BitsetType merged;
    for (auto chunk_it = results.begin(); chunk_it != results.end();
         ++chunk_it) {
        AppendOneChunk(merged, *chunk_it);
    }
    return merged;
}
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
@ -178,8 +246,7 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
auto indexing_barrier = segment_.num_chunk_index(field_id);
auto size_per_chunk = segment_.size_per_chunk();
auto num_chunk = upper_div(row_count_, size_per_chunk);
std::deque<BitsetType> results;
std::vector<FixedVector<bool>> results;
typedef std::
conditional_t<std::is_same_v<T, std::string_view>, std::string, T>
IndexInnerType;
@ -190,24 +257,25 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id,
// NOTE: knowhere is not const-ready
// This is a dirty workaround
auto data = index_func(const_cast<Index*>(&indexing));
AssertInfo(data->size() == size_per_chunk,
AssertInfo(data.size() == size_per_chunk,
"[ExecExprVisitor]Data size not equal to size_per_chunk");
results.emplace_back(std::move(*data));
results.emplace_back(std::move(data));
}
for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) {
auto this_size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
: size_per_chunk;
BitsetType result(this_size);
FixedVector<bool> chunk_res(this_size);
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
const T* data = chunk.data();
// Can use CPU SIMD optimization to speed up
for (int index = 0; index < this_size; ++index) {
result[index] = element_func(data[index]);
auto x = data[index];
chunk_res[index] = element_func(x);
}
results.emplace_back(std::move(result));
results.emplace_back(std::move(chunk_res));
}
auto final_result = Assemble(results);
auto final_result = AssembleChunk(results);
AssertInfo(final_result.size() == row_count_,
"[ExecExprVisitor]Final result size not equal to row count");
return final_result;
@ -227,7 +295,7 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
auto data_barrier = segment_.num_chunk_data(field_id);
AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk,
"max(data_barrier, index_barrier) not equal to num_chunk");
std::deque<BitsetType> results;
std::vector<FixedVector<bool>> results;
// for growing segment, indexing_barrier will always less than data_barrier
// so growing segment will always execute expr plan using raw data
@ -237,15 +305,15 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
auto this_size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
: size_per_chunk;
BitsetType result(this_size);
FixedVector<bool> result(this_size);
auto chunk = segment_.chunk_data<T>(field_id, chunk_id);
const T* data = chunk.data();
for (int index = 0; index < this_size; ++index) {
result[index] = element_func(data[index]);
}
AssertInfo(
result.size() == this_size,
"[ExecExprVisitor]Chunk result size not equal to expected size");
AssertInfo(result.size() == this_size,
"[ExecExprVisitor]Chunk result size not equal to "
"expected size");
results.emplace_back(std::move(result));
}
@ -260,14 +328,14 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id,
auto& indexing =
segment_.chunk_scalar_index<IndexInnerType>(field_id, chunk_id);
auto this_size = const_cast<Index*>(&indexing)->Count();
BitsetType result(this_size);
FixedVector<bool> result(this_size);
for (int offset = 0; offset < this_size; ++offset) {
result[offset] = index_func(const_cast<Index*>(&indexing), offset);
}
results.emplace_back(std::move(result));
}
auto final_result = Assemble(results);
auto final_result = AssembleChunk(results);
AssertInfo(final_result.size() == row_count_,
"[ExecExprVisitor]Final result size not equal to row count");
return final_result;
@ -360,7 +428,7 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw)
auto val = expr.value_;
auto& nested_path = expr.column_.nested_path;
auto field_id = expr.column_.field_id;
auto index_func = [=](Index* index) { return TargetBitmapPtr{}; };
auto index_func = [=](Index* index) { return TargetBitmap{}; };
switch (op) {
case OpType::Equal: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
@ -861,7 +929,7 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw)
auto& nested_path = expr.column_.nested_path;
// no json index now
auto index_func = [=](Index* index) { return TargetBitmapPtr{}; };
auto index_func = [=](Index* index) { return TargetBitmap{}; };
if (lower_inclusive && upper_inclusive) {
auto elem_func = [&](const milvus::Json& json) {
@ -1139,6 +1207,130 @@ struct relational {
}
};
// Compares one chunk of the left field (already fetched by the caller as
// `left_raw_data`) with the same chunk of the right field, element by element.
//
// T / U are the element types of the left / right field. The returned bitmap
// has one entry per row of the chunk: entry i is
// cmp_func(left_raw_data[i], right_raw_data[i]). The last chunk may hold
// fewer rows than size_per_chunk.
template <typename T, typename U, typename CmpFunc>
TargetBitmap
ExecExprVisitor::ExecCompareRightType(const T* left_raw_data,
                                      const FieldId& right_field_id,
                                      const int64_t current_chunk_id,
                                      CmpFunc cmp_func) {
    auto rows_per_chunk = segment_.size_per_chunk();
    auto chunk_count = upper_div(row_count_, rows_per_chunk);
    // Only the final chunk can be partially filled.
    auto rows = current_chunk_id == chunk_count - 1
                    ? row_count_ - current_chunk_id * rows_per_chunk
                    : rows_per_chunk;

    const U* rhs =
        segment_.chunk_data<U>(right_field_id, current_chunk_id).data();

    TargetBitmap matches(rows);
    for (int row = 0; row < rows; ++row) {
        matches[row] = cmp_func(left_raw_data[row], rhs[row]);
    }

    return matches;
}
// Runs a column-vs-column comparison over every chunk of the segment for a
// left field whose element type T is already known.
//
// For each chunk the raw left data is fetched, and the right field's runtime
// type (`right_field_type`) selects the matching ExecCompareRightType
// instantiation. The per-chunk results are then concatenated into a single
// bitset that must contain exactly row_count_ bits.
//
// Panics when `right_field_type` is not BOOL, INT8/16/32/64, FLOAT or DOUBLE.
template <typename T, typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id,
                                     const FieldId& right_field_id,
                                     const DataType& right_field_type,
                                     CmpFunc cmp_func) {
    std::vector<FixedVector<bool>> results;
    auto size_per_chunk = segment_.size_per_chunk();
    auto num_chunks = upper_div(row_count_, size_per_chunk);

    for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
        FixedVector<bool> result;
        const T* left_raw_data =
            segment_.chunk_data<T>(left_field_id, chunk_id).data();

        switch (right_field_type) {
            case DataType::BOOL:
                result = ExecCompareRightType<T, bool, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT8:
                result = ExecCompareRightType<T, int8_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT16:
                result = ExecCompareRightType<T, int16_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT32:
                result = ExecCompareRightType<T, int32_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::INT64:
                result = ExecCompareRightType<T, int64_t, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::FLOAT:
                result = ExecCompareRightType<T, float, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            case DataType::DOUBLE:
                result = ExecCompareRightType<T, double, CmpFunc>(
                    left_raw_data, right_field_id, chunk_id, cmp_func);
                break;
            default:
                // This switch is over the RIGHT field's type, so the message
                // names the right side.
                PanicInfo("unsupported right datatype of compare expr");
        }
        // The chunk result is not used again; move it instead of copying.
        results.push_back(std::move(result));
    }
    auto final_result = AssembleChunk(results);
    AssertInfo(final_result.size() == row_count_,
               "[ExecExprVisitor]Size of results not equal row count");
    return final_result;
}
// Entry point for comparing two columns by calling ExecCompareLeftType on
// their raw chunk data.
//
// Maps the left field's runtime type (expr.left_data_type_) to the matching
// ExecCompareLeftType instantiation, which in turn dispatches on the right
// field's type. Returns one bit per row.
//
// Panics when the left type is not BOOL, INT8/16/32/64, FLOAT or DOUBLE.
template <typename CmpFunc>
BitsetType
ExecExprVisitor::ExecCompareExprDispatcherForNonIndexedSegment(
    CompareExpr& expr, CmpFunc cmp_func) {
    switch (expr.left_data_type_) {
        case DataType::BOOL:
            return ExecCompareLeftType<bool, CmpFunc>(expr.left_field_id_,
                                                      expr.right_field_id_,
                                                      expr.right_data_type_,
                                                      cmp_func);
        case DataType::INT8:
            return ExecCompareLeftType<int8_t, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        case DataType::INT16:
            return ExecCompareLeftType<int16_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT32:
            return ExecCompareLeftType<int32_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::INT64:
            return ExecCompareLeftType<int64_t, CmpFunc>(expr.left_field_id_,
                                                         expr.right_field_id_,
                                                         expr.right_data_type_,
                                                         cmp_func);
        case DataType::FLOAT:
            return ExecCompareLeftType<float, CmpFunc>(expr.left_field_id_,
                                                       expr.right_field_id_,
                                                       expr.right_data_type_,
                                                       cmp_func);
        case DataType::DOUBLE:
            return ExecCompareLeftType<double, CmpFunc>(expr.left_field_id_,
                                                        expr.right_field_id_,
                                                        expr.right_data_type_,
                                                        cmp_func);
        default:
            // This switch is over the LEFT field's type, so the message
            // names the left side.
            PanicInfo("unsupported left datatype of compare expr");
    }
}
template <typename Op>
auto
ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
@ -1151,6 +1343,11 @@ ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
float,
double,
std::string>;
auto is_string_expr = [&expr]() -> bool {
return expr.left_data_type_ == DataType::VARCHAR ||
expr.right_data_type_ == DataType::VARCHAR;
};
auto size_per_chunk = segment_.size_per_chunk();
auto num_chunk = upper_div(row_count_, size_per_chunk);
std::deque<BitsetType> bitsets;
@ -1170,6 +1367,14 @@ ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op)
"max(right_data_barrier, right_indexing_barrier) not equal to "
"num_chunk");
// When neither field has an index on this segment, SIMD can speed this up.
// Dispatch early to avoid the deep call stack below, which blocks SIMD.
if (left_indexing_barrier == 0 && right_indexing_barrier == 0 &&
!is_string_expr()) {
return ExecCompareExprDispatcherForNonIndexedSegment<Op>(expr, op);
}
// TODO: refactoring the code that contains too much call stack.
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto size = chunk_id == num_chunk - 1
? row_count_ - chunk_id * size_per_chunk
@ -1355,8 +1560,8 @@ ExecExprVisitor::visit(CompareExpr& expr) {
auto& left_field_meta = schema[expr.left_field_id_];
auto& right_field_meta = schema[expr.right_field_id_];
AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(),
"[ExecExprVisitor]Left data type not equal to left "
"field meta type");
"[ExecExprVisitor]Left data type not equal to left field "
"meta type");
AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(),
"[ExecExprVisitor]right data type not equal to right field "
"meta type");
@ -1534,7 +1739,7 @@ ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw)
using Index = index::ScalarIndex<milvus::Json>;
auto& expr = static_cast<TermExprImpl<ExprValueType>&>(expr_raw);
auto& nested_path = expr.column_.nested_path;
auto index_func = [=](Index* index) { return TargetBitmapPtr{}; };
auto index_func = [=](Index* index) { return TargetBitmap{}; };
std::unordered_set<ExprValueType> term_set(expr.terms_.begin(),
expr.terms_.end());
@ -1647,7 +1852,7 @@ ExecExprVisitor::visit(ExistsExpr& expr) {
switch (expr.column_.data_type) {
case DataType::JSON: {
using Index = index::ScalarIndex<milvus::Json>;
auto index_func = [=](Index* index) { return TargetBitmapPtr{}; };
auto index_func = [=](Index* index) { return TargetBitmap{}; };
auto elem_func = [nested_path](const milvus::Json& json) {
auto x = json.exist(nested_path);
return x;

View File

@ -26,12 +26,20 @@ TEST(Bitmap, Naive) {
sort_index->Build(N, vec.data());
{
auto res = sort_index->Range(0, OpType::LessThan);
double count = res->count();
double count = 0;
for (size_t i = 0; i < res.size(); ++i) {
if (res[i] == true)
count++;
}
ASSERT_NEAR(count / N, 0.5, 0.01);
}
{
auto res = sort_index->Range(-1, false, 1, true);
double count = res->count();
double count = 0;
for (size_t i = 0; i < res.size(); ++i) {
if (res[i] == true)
count++;
}
ASSERT_NEAR(count / N, 0.682, 0.01);
}
}
}

View File

@ -13,6 +13,7 @@
#include <pb/schema.pb.h>
#include <index/BoolIndex.h>
#include "test_utils/indexbuilder_test_utils.h"
#include "test_utils/AssertUtils.h"
class BoolIndexTest : public ::testing::Test {
protected:
@ -81,10 +82,10 @@ TEST_F(BoolIndexTest, In) {
index->Build(all_true.data_size(), all_true.data().data());
auto bitset1 = index->In(1, true_test.get());
ASSERT_TRUE(bitset1->any());
ASSERT_TRUE(Any(bitset1));
auto bitset2 = index->In(1, false_test.get());
ASSERT_TRUE(bitset2->none());
ASSERT_TRUE(BitSetNone(bitset2));
}
{
@ -92,10 +93,10 @@ TEST_F(BoolIndexTest, In) {
index->Build(all_false.data_size(), all_false.data().data());
auto bitset1 = index->In(1, true_test.get());
ASSERT_TRUE(bitset1->none());
ASSERT_TRUE(BitSetNone(bitset1));
auto bitset2 = index->In(1, false_test.get());
ASSERT_TRUE(bitset2->any());
ASSERT_TRUE(Any(bitset2));
}
{
@ -104,12 +105,12 @@ TEST_F(BoolIndexTest, In) {
auto bitset1 = index->In(1, true_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset1->test(i), (i % 2) == 0);
ASSERT_EQ(bitset1[i], (i % 2) == 0);
}
auto bitset2 = index->In(1, false_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset2->test(i), (i % 2) != 0);
ASSERT_EQ(bitset2[i], (i % 2) != 0);
}
}
}
@ -123,10 +124,10 @@ TEST_F(BoolIndexTest, NotIn) {
index->Build(all_true.data_size(), all_true.data().data());
auto bitset1 = index->NotIn(1, true_test.get());
ASSERT_TRUE(bitset1->none());
ASSERT_TRUE(BitSetNone(bitset1));
auto bitset2 = index->NotIn(1, false_test.get());
ASSERT_TRUE(bitset2->any());
ASSERT_TRUE(Any(bitset2));
}
{
@ -134,10 +135,10 @@ TEST_F(BoolIndexTest, NotIn) {
index->Build(all_false.data_size(), all_false.data().data());
auto bitset1 = index->NotIn(1, true_test.get());
ASSERT_TRUE(bitset1->any());
ASSERT_TRUE(Any(bitset1));
auto bitset2 = index->NotIn(1, false_test.get());
ASSERT_TRUE(bitset2->none());
ASSERT_TRUE(BitSetNone(bitset2));
}
{
@ -146,12 +147,12 @@ TEST_F(BoolIndexTest, NotIn) {
auto bitset1 = index->NotIn(1, true_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
ASSERT_EQ(bitset1[i], (i % 2) != 0);
}
auto bitset2 = index->NotIn(1, false_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
ASSERT_EQ(bitset2[i], (i % 2) == 0);
}
}
}
@ -168,10 +169,10 @@ TEST_F(BoolIndexTest, Codec) {
copy_index->Load(index->Serialize(nullptr));
auto bitset1 = copy_index->NotIn(1, true_test.get());
ASSERT_TRUE(bitset1->none());
ASSERT_TRUE(BitSetNone(bitset1));
auto bitset2 = copy_index->NotIn(1, false_test.get());
ASSERT_TRUE(bitset2->any());
ASSERT_TRUE(Any(bitset2));
}
{
@ -182,10 +183,10 @@ TEST_F(BoolIndexTest, Codec) {
copy_index->Load(index->Serialize(nullptr));
auto bitset1 = copy_index->NotIn(1, true_test.get());
ASSERT_TRUE(bitset1->any());
ASSERT_TRUE(Any(bitset1));
auto bitset2 = copy_index->NotIn(1, false_test.get());
ASSERT_TRUE(bitset2->none());
ASSERT_TRUE(BitSetNone(bitset2));
}
{
@ -197,12 +198,12 @@ TEST_F(BoolIndexTest, Codec) {
auto bitset1 = copy_index->NotIn(1, true_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset1->test(i), (i % 2) != 0);
ASSERT_EQ(bitset1[i], (i % 2) != 0);
}
auto bitset2 = copy_index->NotIn(1, false_test.get());
for (size_t i = 0; i < n; i++) {
ASSERT_EQ(bitset2->test(i), (i % 2) == 0);
ASSERT_EQ(bitset2[i], (i % 2) == 0);
}
}
}

View File

@ -30,6 +30,7 @@
#include "test_utils/DataGen.h"
#include "test_utils/PbHelper.h"
#include "test_utils/indexbuilder_test_utils.h"
#include "query/generated/ExecExprVisitor.h"
namespace chrono = std::chrono;
@ -4419,3 +4420,41 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) {
DeleteCollection(c_collection);
DeleteSegment(segment);
}
// Verifies that AppendOneChunk appends each chunk's flags to the end of the
// bitset in order, across several successive appends.
//
// The chunk sizes are chosen so that the bitset is left at differing
// (non block-aligned) lengths between appends. A fresh chunk is built for
// every round: reusing one ever-growing vector would re-append all earlier
// data and make the position checks compare against stale offsets.
TEST(CApiTest, AssembeChunkTest) {
    const size_t chunk_sizes[] = {1000, 934, 62, 105};

    BitsetType result;
    size_t index = 0;
    for (size_t chunk_size : chunk_sizes) {
        FixedVector<bool> chunk;
        for (size_t i = 0; i < chunk_size; ++i) {
            chunk.push_back(i % 2 == 0);
        }

        milvus::query::AppendOneChunk(result, chunk);

        // Exactly chunk_size bits must have been added ...
        ASSERT_EQ(result.size(), index + chunk_size);
        // ... and they must mirror the chunk, in order, at the tail.
        for (size_t i = 0; i < chunk_size; i++) {
            ASSERT_EQ(result[index++], chunk[i]) << i;
        }
    }
}

View File

@ -15,6 +15,7 @@
#include <memory>
#include <regex>
#include <vector>
#include <chrono>
#include "common/Json.h"
#include "common/Types.h"
@ -27,6 +28,7 @@
#include "query/generated/ExecExprVisitor.h"
#include "segcore/SegmentGrowingImpl.h"
#include "simdjson/padded_string.h"
#include "segcore/segment_c.h"
#include "test_utils/DataGen.h"
#include "index/IndexFactory.h"
@ -1046,6 +1048,301 @@ TEST(Expr, TestCompareWithScalarIndex) {
}
}
// Smoke test for CompareExpr evaluation on a sealed segment.
// A schema with two columns per scalar type is filled with generated data,
// then a "left < right" CompareExpr is evaluated for BOOL, INT8, INT16, INT32,
// INT64, FLOAT and DOUBLE column pairs. The results are not asserted on; the
// test only exercises the evaluation path.
TEST(Expr, TestCompareExpr) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    // Fields are registered in a fixed order; keep it stable.
    auto schema = std::make_shared<Schema>();
    auto vec_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
    auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL);
    auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
    auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
    auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
    auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16);
    auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
    auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64);
    auto float_fid = schema->AddDebugField("float", DataType::FLOAT);
    auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT);
    auto double_fid = schema->AddDebugField("double", DataType::DOUBLE);
    auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE);
    auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
    auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR);
    auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR);
    schema->set_primary_field_id(str1_fid);

    auto seg = CreateSealedSegment(schema);
    int N = 1000;
    auto raw_data = DataGen(schema, N);

    // Load every generated column into the sealed segment.
    for (auto& [field_id, field_meta] : schema->get_fields()) {
        auto array = raw_data.get_col(field_id);
        auto data_info =
            LoadFieldDataInfo{field_id.get(), array.get(), N, "/tmp/a"};
        seg->LoadFieldData(data_info);
    }

    ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP);

    // Builds a "left < right" comparison between two columns of `type`.
    auto less_than = [](enum DataType type,
                        const auto& left,
                        const auto& right) -> std::shared_ptr<query::Expr> {
        auto cmp = std::make_shared<query::CompareExpr>();
        cmp->op_type_ = OpType::LessThan;
        cmp->left_data_type_ = type;
        cmp->left_field_id_ = left;
        cmp->right_data_type_ = type;
        cmp->right_field_id_ = right;
        return cmp;
    };

    // Maps a data type to the comparison over its column pair; any other
    // type yields a default-constructed CompareExpr.
    auto build_expr = [&](enum DataType type) -> std::shared_ptr<query::Expr> {
        switch (type) {
            case DataType::BOOL:
                return less_than(type, bool_fid, bool_1_fid);
            case DataType::INT8:
                return less_than(type, int8_fid, int8_1_fid);
            case DataType::INT16:
                return less_than(type, int16_fid, int16_1_fid);
            case DataType::INT32:
                return less_than(type, int32_fid, int32_1_fid);
            case DataType::INT64:
                return less_than(type, int64_fid, int64_1_fid);
            case DataType::FLOAT:
                return less_than(type, float_fid, float_1_fid);
            case DataType::DOUBLE:
                return less_than(type, double_fid, double_1_fid);
            case DataType::VARCHAR:
                return less_than(type, str2_fid, str3_fid);
            default:
                return std::make_shared<query::CompareExpr>();
        }
    };

    std::cout << "start compare test" << std::endl;
    for (auto type : {DataType::BOOL,
                      DataType::INT8,
                      DataType::INT16,
                      DataType::INT32,
                      DataType::INT64,
                      DataType::FLOAT,
                      DataType::DOUBLE}) {
        auto expr = build_expr(type);
        auto result = visitor.call_child(*expr);
    }
    std::cout << "end compare test" << std::endl;
}
// Timing harness for expression evaluation on a sealed segment.
// Loads N generated rows, builds one expression through `build_expr`
// (currently the int8 "greater than 10" unary range) and prints how many
// microseconds ExecExprVisitor::call_child takes. Nothing is asserted.
TEST(Expr, TestExprs) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    auto schema = std::make_shared<Schema>();
    auto vec_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
    auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
    auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
    auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16);
    auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
    auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64);
    auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
    auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR);
    schema->set_primary_field_id(str1_fid);

    auto seg = CreateSealedSegment(schema);
    int N = 1000000;
    auto raw_data = DataGen(schema, N);

    // Load every generated column, logging the field id and name as we go.
    for (auto& [field_id, field_meta] : schema->get_fields()) {
        std::cout << field_id.get() << field_meta.get_name().get() << std::endl;
        auto array = raw_data.get_col(field_id);
        auto data_info =
            LoadFieldDataInfo{field_id.get(), array.get(), N, "/tmp/a"};
        seg->LoadFieldData(data_info);
    }

    ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP);

    // Selector for the expression shape that build_expr should produce.
    enum class ExprKind {
        UnaryRange = 0,
        Term = 1,
        Compare = 2,
        LogicalUnary = 3,
        BinaryRange = 4,
        LogicalBinary = 5,
        BinaryArithOpEvalRange = 6,
    };

    auto build_expr = [&](ExprKind kind) -> std::shared_ptr<query::Expr> {
        switch (kind) {
            case ExprKind::UnaryRange:
                return std::make_shared<query::UnaryRangeExprImpl<int8_t>>(
                    ColumnInfo(int8_fid, DataType::INT8),
                    proto::plan::OpType::GreaterThan,
                    10,
                    proto::plan::GenericValue::ValCase::kInt64Val);

            case ExprKind::Term: {
                std::vector<int64_t> retrieve_ints = {1, 4, 6};
                return std::make_shared<query::TermExprImpl<int64_t>>(
                    ColumnInfo(int64_fid, DataType::INT64),
                    retrieve_ints,
                    proto::plan::GenericValue::ValCase::kInt64Val);
            }

            case ExprKind::Compare: {
                auto cmp = std::make_shared<query::CompareExpr>();
                cmp->op_type_ = OpType::LessThan;
                cmp->left_data_type_ = DataType::INT8;
                cmp->left_field_id_ = int8_fid;
                cmp->right_data_type_ = DataType::INT8;
                cmp->right_field_id_ = int8_1_fid;
                return cmp;
            }

            case ExprKind::BinaryRange:
                return std::make_shared<query::BinaryRangeExprImpl<int64_t>>(
                    ColumnInfo(int64_fid, DataType::INT64),
                    proto::plan::GenericValue::ValCase::kInt64Val,
                    true,
                    true,
                    10,
                    45);

            case ExprKind::LogicalUnary: {
                // NOT (int32 > 10)
                ExprPtr inner =
                    std::make_unique<query::UnaryRangeExprImpl<int32_t>>(
                        ColumnInfo(int32_fid, DataType::INT32),
                        proto::plan::OpType::GreaterThan,
                        10,
                        proto::plan::GenericValue::ValCase::kInt64Val);
                return std::make_shared<query::LogicalUnaryExpr>(
                    LogicalUnaryExpr::OpType::LogicalNot, inner);
            }

            case ExprKind::LogicalBinary: {
                // (int8 > 10) XOR (int8 != 10)
                ExprPtr lhs =
                    std::make_unique<query::UnaryRangeExprImpl<int8_t>>(
                        ColumnInfo(int8_fid, DataType::INT8),
                        proto::plan::OpType::GreaterThan,
                        10,
                        proto::plan::GenericValue::ValCase::kInt64Val);
                ExprPtr rhs =
                    std::make_unique<query::UnaryRangeExprImpl<int8_t>>(
                        ColumnInfo(int8_fid, DataType::INT8),
                        proto::plan::OpType::NotEqual,
                        10,
                        proto::plan::GenericValue::ValCase::kInt64Val);
                return std::make_shared<query::LogicalBinaryExpr>(
                    LogicalBinaryExpr::OpType::LogicalXor, lhs, rhs);
            }

            case ExprKind::BinaryArithOpEvalRange:
                return std::make_shared<
                    query::BinaryArithOpEvalRangeExprImpl<int8_t>>(
                    ColumnInfo(int8_fid, DataType::INT8),
                    proto::plan::GenericValue::ValCase::kInt64Val,
                    proto::plan::ArithOpType::Add,
                    10,
                    proto::plan::OpType::Equal,
                    100);

            default:
                // Same expression as the BinaryRange case.
                return std::make_shared<query::BinaryRangeExprImpl<int64_t>>(
                    ColumnInfo(int64_fid, DataType::INT64),
                    proto::plan::GenericValue::ValCase::kInt64Val,
                    true,
                    true,
                    10,
                    45);
        }
    };

    auto expr = build_expr(ExprKind::UnaryRange);
    std::cout << "start test" << std::endl;
    auto start = std::chrono::steady_clock::now();
    auto result = visitor.call_child(*expr);
    auto elapsed_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::steady_clock::now() - start)
                          .count();
    std::cout << "cost: " << elapsed_us << "us" << std::endl;
}
TEST(Expr, TestCompareWithScalarIndexMaris) {
using namespace milvus::query;
using namespace milvus::segcore;

View File

@ -57,16 +57,16 @@ TEST_F(StringIndexMarisaTest, In) {
auto index = milvus::index::CreateStringIndexMarisa();
index->Build(nb, strs.data());
auto bitset = index->In(strs.size(), strs.data());
ASSERT_EQ(bitset->size(), strs.size());
ASSERT_TRUE(bitset->any());
ASSERT_EQ(bitset.size(), strs.size());
ASSERT_TRUE(Any(bitset));
}
TEST_F(StringIndexMarisaTest, NotIn) {
auto index = milvus::index::CreateStringIndexMarisa();
index->Build(nb, strs.data());
auto bitset = index->NotIn(strs.size(), strs.data());
ASSERT_EQ(bitset->size(), strs.size());
ASSERT_TRUE(bitset->none());
ASSERT_EQ(bitset.size(), strs.size());
ASSERT_TRUE(BitSetNone(bitset));
}
TEST_F(StringIndexMarisaTest, Range) {
@ -79,32 +79,32 @@ TEST_F(StringIndexMarisaTest, Range) {
{
auto bitset = index->Range("0", milvus::OpType::GreaterEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = index->Range("90", milvus::OpType::LessThan);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = index->Range("9", milvus::OpType::LessEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = index->Range("0", true, "9", true);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = index->Range("0", true, "90", false);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
}
@ -125,8 +125,8 @@ TEST_F(StringIndexMarisaTest, PrefixMatch) {
for (size_t i = 0; i < strs.size(); i++) {
auto str = strs[i];
auto bitset = index->PrefixMatch(str);
ASSERT_EQ(bitset->size(), strs.size());
ASSERT_TRUE(bitset->test(i));
ASSERT_EQ(bitset.size(), strs.size());
ASSERT_TRUE(bitset[i]);
}
}
@ -139,7 +139,7 @@ TEST_F(StringIndexMarisaTest, Query) {
ds->Set<milvus::OpType>(milvus::index::OPERATOR_TYPE,
milvus::OpType::In);
auto bitset = index->Query(ds);
ASSERT_TRUE(bitset->any());
ASSERT_TRUE(Any(bitset));
}
{
@ -147,7 +147,7 @@ TEST_F(StringIndexMarisaTest, Query) {
ds->Set<milvus::OpType>(milvus::index::OPERATOR_TYPE,
milvus::OpType::NotIn);
auto bitset = index->Query(ds);
ASSERT_TRUE(bitset->none());
ASSERT_TRUE(BitSetNone(bitset));
}
{
@ -156,8 +156,8 @@ TEST_F(StringIndexMarisaTest, Query) {
milvus::OpType::GreaterEqual);
ds->Set<std::string>(milvus::index::RANGE_VALUE, "0");
auto bitset = index->Query(ds);
ASSERT_EQ(bitset->size(), strs.size());
ASSERT_EQ(bitset->count(), strs.size());
ASSERT_EQ(bitset.size(), strs.size());
ASSERT_EQ(Count(bitset), strs.size());
}
{
@ -169,7 +169,7 @@ TEST_F(StringIndexMarisaTest, Query) {
ds->Set<bool>(milvus::index::LOWER_BOUND_INCLUSIVE, true);
ds->Set<bool>(milvus::index::UPPER_BOUND_INCLUSIVE, true);
auto bitset = index->Query(ds);
ASSERT_TRUE(bitset->any());
ASSERT_TRUE(Any(bitset));
}
{
@ -180,8 +180,8 @@ TEST_F(StringIndexMarisaTest, Query) {
ds->Set<std::string>(milvus::index::PREFIX_VALUE,
std::move(strs[i]));
auto bitset = index->Query(ds);
ASSERT_EQ(bitset->size(), strs.size());
ASSERT_TRUE(bitset->test(i));
ASSERT_EQ(bitset.size(), strs.size());
ASSERT_TRUE(bitset[i]);
}
}
}
@ -205,58 +205,58 @@ TEST_F(StringIndexMarisaTest, Codec) {
{
auto bitset = copy_index->In(nb, strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->any());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(Any(bitset));
}
{
auto bitset = copy_index->In(1, invalid_strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->none());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(BitSetNone(bitset));
}
{
auto bitset = copy_index->NotIn(nb, strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->none());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(BitSetNone(bitset));
}
{
auto bitset = copy_index->Range("0", milvus::OpType::GreaterEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("90", milvus::OpType::LessThan);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("9", milvus::OpType::LessEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("0", true, "9", true);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("0", true, "90", false);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
for (size_t i = 0; i < nb; i++) {
auto str = strings[i];
auto bitset = copy_index->PrefixMatch(str);
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->test(i));
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(bitset[i]);
}
}
}
@ -283,58 +283,58 @@ TEST_F(StringIndexMarisaTest, BaseIndexCodec) {
{
auto bitset = copy_index->In(nb, strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->any());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(Any(bitset));
}
{
auto bitset = copy_index->In(1, invalid_strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->none());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(BitSetNone(bitset));
}
{
auto bitset = copy_index->NotIn(nb, strings.data());
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->none());
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(BitSetNone(bitset));
}
{
auto bitset = copy_index->Range("0", milvus::OpType::GreaterEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("90", milvus::OpType::LessThan);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("9", milvus::OpType::LessEqual);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("0", true, "9", true);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
auto bitset = copy_index->Range("0", true, "90", false);
ASSERT_EQ(bitset->size(), nb);
ASSERT_EQ(bitset->count(), nb);
ASSERT_EQ(bitset.size(), nb);
ASSERT_EQ(Count(bitset), nb);
}
{
for (size_t i = 0; i < nb; i++) {
auto str = strings[i];
auto bitset = copy_index->PrefixMatch(str);
ASSERT_EQ(bitset->size(), nb);
ASSERT_TRUE(bitset->test(i));
ASSERT_EQ(bitset.size(), nb);
ASSERT_TRUE(bitset[i]);
}
}
}

View File

@ -15,6 +15,8 @@
#include <vector>
#include <memory>
#include "common/Types.h"
using milvus::index::ScalarIndex;
namespace {
@ -33,6 +35,36 @@ compare_double(double x, double y, double epsilon = 0.000001f) {
return false;
}
// Returns true when at least one element of `vec` is true, i.e. the same
// meaning as a bitset's any(). An empty vector yields false.
bool
Any(const milvus::FixedVector<bool>& vec) {
    for (bool val : vec) {
        if (val) {
            return true;
        }
    }
    return false;
}
// Returns true when no element of `vec` is true (also for an empty vector).
bool
BitSetNone(const milvus::FixedVector<bool>& vec) {
    for (size_t i = 0; i < vec.size(); ++i) {
        if (vec[i]) {
            return false;
        }
    }
    return true;
}
// Returns how many elements of `vec` are true.
uint64_t
Count(const milvus::FixedVector<bool>& vec) {
    uint64_t total = 0;
    for (bool bit : vec) {
        if (bit) {
            ++total;
        }
    }
    return total;
}
inline void
assert_order(const milvus::SearchResult& result,
const knowhere::MetricType& metric_type) {
@ -71,24 +103,24 @@ assert_in(ScalarIndex<T>* index, const std::vector<T>& arr) {
}
auto bitset1 = index->In(arr.size(), arr.data());
ASSERT_EQ(arr.size(), bitset1->size());
ASSERT_TRUE(bitset1->any());
ASSERT_EQ(arr.size(), bitset1.size());
ASSERT_TRUE(Any(bitset1));
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
auto bitset2 = index->In(1, test.get());
ASSERT_EQ(arr.size(), bitset2->size());
ASSERT_TRUE(bitset2->none());
ASSERT_EQ(arr.size(), bitset2.size());
ASSERT_TRUE(BitSetNone(bitset2));
}
template <typename T>
inline void
assert_not_in(ScalarIndex<T>* index, const std::vector<T>& arr) {
auto bitset1 = index->NotIn(arr.size(), arr.data());
ASSERT_EQ(arr.size(), bitset1->size());
ASSERT_TRUE(bitset1->none());
ASSERT_EQ(arr.size(), bitset1.size());
ASSERT_TRUE(BitSetNone(bitset1));
auto test = std::make_unique<T>(arr[arr.size() - 1] + 1);
auto bitset2 = index->NotIn(1, test.get());
ASSERT_EQ(arr.size(), bitset2->size());
ASSERT_TRUE(bitset2->any());
ASSERT_EQ(arr.size(), bitset2.size());
ASSERT_TRUE(Any(bitset2));
}
template <typename T>
@ -98,24 +130,24 @@ assert_range(ScalarIndex<T>* index, const std::vector<T>& arr) {
auto test_max = arr[arr.size() - 1];
auto bitset1 = index->Range(test_min - 1, milvus::OpType::GreaterThan);
ASSERT_EQ(arr.size(), bitset1->size());
ASSERT_TRUE(bitset1->any());
ASSERT_EQ(arr.size(), bitset1.size());
ASSERT_TRUE(Any(bitset1));
auto bitset2 = index->Range(test_min, milvus::OpType::GreaterEqual);
ASSERT_EQ(arr.size(), bitset2->size());
ASSERT_TRUE(bitset2->any());
ASSERT_EQ(arr.size(), bitset2.size());
ASSERT_TRUE(Any(bitset2));
auto bitset3 = index->Range(test_max + 1, milvus::OpType::LessThan);
ASSERT_EQ(arr.size(), bitset3->size());
ASSERT_TRUE(bitset3->any());
ASSERT_EQ(arr.size(), bitset3.size());
ASSERT_TRUE(Any(bitset3));
auto bitset4 = index->Range(test_max, milvus::OpType::LessEqual);
ASSERT_EQ(arr.size(), bitset4->size());
ASSERT_TRUE(bitset4->any());
ASSERT_EQ(arr.size(), bitset4.size());
ASSERT_TRUE(Any(bitset4));
auto bitset5 = index->Range(test_min, true, test_max, true);
ASSERT_EQ(arr.size(), bitset5->size());
ASSERT_TRUE(bitset5->any());
ASSERT_EQ(arr.size(), bitset5.size());
ASSERT_TRUE(Any(bitset5));
}
template <typename T>
@ -156,8 +188,8 @@ inline void
assert_in(ScalarIndex<std::string>* index,
const std::vector<std::string>& arr) {
auto bitset1 = index->In(arr.size(), arr.data());
ASSERT_EQ(arr.size(), bitset1->size());
ASSERT_TRUE(bitset1->any());
ASSERT_EQ(arr.size(), bitset1.size());
ASSERT_TRUE(Any(bitset1));
}
template <>
@ -165,8 +197,8 @@ inline void
assert_not_in(ScalarIndex<std::string>* index,
const std::vector<std::string>& arr) {
auto bitset1 = index->NotIn(arr.size(), arr.data());
ASSERT_EQ(arr.size(), bitset1->size());
ASSERT_TRUE(bitset1->none());
ASSERT_EQ(arr.size(), bitset1.size());
ASSERT_TRUE(BitSetNone(bitset1));
}
template <>
@ -177,15 +209,15 @@ assert_range(ScalarIndex<std::string>* index,
auto test_max = arr[arr.size() - 1];
auto bitset2 = index->Range(test_min, milvus::OpType::GreaterEqual);
ASSERT_EQ(arr.size(), bitset2->size());
ASSERT_TRUE(bitset2->any());
ASSERT_EQ(arr.size(), bitset2.size());
ASSERT_TRUE(Any(bitset2));
auto bitset4 = index->Range(test_max, milvus::OpType::LessEqual);
ASSERT_EQ(arr.size(), bitset4->size());
ASSERT_TRUE(bitset4->any());
ASSERT_EQ(arr.size(), bitset4.size());
ASSERT_TRUE(Any(bitset4));
auto bitset5 = index->Range(test_min, true, test_max, true);
ASSERT_EQ(arr.size(), bitset5->size());
ASSERT_TRUE(bitset5->any());
ASSERT_EQ(arr.size(), bitset5.size());
ASSERT_TRUE(Any(bitset5));
}
} // namespace

View File

@ -252,6 +252,14 @@ DataGen(SchemaPtr schema,
insert_cols(data, N, field_meta);
break;
}
case DataType::BOOL: {
FixedVector<bool> data(N);
for (int i = 0; i < N; ++i) {
data[i] = i % 2 == 0 ? true : false;
}
insert_cols(data, N, field_meta);
break;
}
case DataType::INT64: {
vector<int64_t> data(N);
for (int i = 0; i < N; i++) {

View File

@ -31,6 +31,51 @@ if [[ ! ${jobs+1} ]]; then
fi
fi
# Prints the CPU architecture / SIMD level to build for (no trailing newline).
#
# Arguments:
#   $1 - optional explicit value; when non-empty it is printed unchanged and
#        no detection is performed.
#
# Detection when $1 is empty:
#   Darwin, x86_64 : "avx" if machdep.cpu.features mentions avx, else "sse".
#   Darwin, other  : "arm64" if the CPU brand string mentions Apple.
#   anything else  : first "flags" line of /proc/cpuinfo (if readable) decides
#                    "avx" or "sse"; otherwise "aarch64" when uname -m says so.
#   Prints an empty string when none of the rules match.
#
# Side effect: resets the global ADDITIONAL_FLAGS to "".
function get_cpu_arch {
    local CPU_ARCH=$1

    local OS
    OS=$(uname)
    local MACHINE
    MACHINE=$(uname -m)
    ADDITIONAL_FLAGS=""

    if [ -z "$CPU_ARCH" ]; then
        if [ "$OS" = "Darwin" ]; then
            if [ "$MACHINE" = "x86_64" ]; then
                local CPU_CAPABILITIES
                CPU_CAPABILITIES=$(sysctl -a | grep machdep.cpu.features | awk '{print tolower($0)}')

                if [[ $CPU_CAPABILITIES =~ "avx" ]]; then
                    CPU_ARCH="avx"
                else
                    CPU_ARCH="sse"
                fi
            elif [[ $(sysctl -a | grep machdep.cpu.brand_string) =~ "Apple" ]]; then
                # Apple silicon.
                CPU_ARCH="arm64"
            fi
        else
            # Linux, plus any other non-Darwin platform. The previous
            # `else [ "$OS" = "Linux" ];` ran the test as a plain command and
            # discarded its result, so every non-Darwin host already took this
            # branch; that behaviour is kept, minus the stray test.
            local CPU_CAPABILITIES=""
            if [ -r /proc/cpuinfo ]; then
                CPU_CAPABILITIES=$(grep flags /proc/cpuinfo | head -n 1 | awk '{print tolower($0)}')
            fi

            if [[ "$CPU_CAPABILITIES" =~ "avx" ]]; then
                CPU_ARCH="avx"
            elif [[ "$CPU_CAPABILITIES" =~ "sse" ]]; then
                CPU_ARCH="sse"
            elif [ "$MACHINE" = "aarch64" ]; then
                CPU_ARCH="aarch64"
            fi
        fi
    fi

    # Quoted so the value is emitted exactly as determined.
    echo -n "$CPU_ARCH"
}
SOURCE="${BASH_SOURCE[0]}"
while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
@ -58,8 +103,9 @@ CUSTOM_THIRDPARTY_PATH=""
EMBEDDED_MILVUS="OFF"
BUILD_DISK_ANN="OFF"
USE_ASAN="OFF"
OPEN_SIMD="OFF"
while getopts "p:d:t:s:f:n:a:ulrcghzmeb" arg; do
while getopts "p:d:t:s:f:n:i:a:ulrcghzmeb" arg; do
case $arg in
f)
CUSTOM_THIRDPARTY_PATH=$OPTARG
@ -114,6 +160,9 @@ while getopts "p:d:t:s:f:n:a:ulrcghzmeb" arg; do
BUILD_TYPE=Debug
fi
;;
i)
OPEN_SIMD=$OPTARG
;;
h) # help
echo "
@ -189,6 +238,8 @@ if [[ ${MAKE_CLEAN} == "ON" ]]; then
exit 0
fi
CPU_ARCH=$(get_cpu_arch $CPU_TARGET)
arch=$(uname -m)
CMAKE_CMD="cmake \
${CMAKE_EXTRA_ARGS} \
@ -208,6 +259,8 @@ ${CMAKE_EXTRA_ARGS} \
-DEMBEDDED_MILVUS=${EMBEDDED_MILVUS} \
-DBUILD_DISK_ANN=${BUILD_DISK_ANN} \
-DUSE_ASAN=${USE_ASAN} \
-DOPEN_SIMD=${OPEN_SIMD} \
-DCPU_ARCH=${CPU_ARCH} \
${CPP_SRC_DIR}"
echo "CC $CC"