enhance: [bitset] multiple 'and' and 'or' in a single op (#33345)

issue #34117
* Refactoring
* Added a capability to perform multiple bitwise `and` and `or`
operations in a single op
* AVX2, AVX512, ARM NEON, ARM SVE backed bitwise `and`, `or`, `xor` and
`sub` ops
* more unit tests for bitset
* fixed a bug in `or_with_count` for certain bitset sizes
* fixed a bug for certain offset values for inplace operations that take
two bitsets

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
This commit is contained in:
Alexander Guzhva 2024-10-22 04:25:33 -04:00 committed by GitHub
parent 6bedc7e8c8
commit 5a1f752272
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 3529 additions and 385 deletions

View File

@ -23,6 +23,7 @@
#include <type_traits> #include <type_traits>
#include "common.h" #include "common.h"
#include "detail/maybe_vector.h"
namespace milvus { namespace milvus {
namespace bitset { namespace bitset {
@ -109,7 +110,6 @@ class BitsetBase {
public: public:
using policy_type = PolicyT; using policy_type = PolicyT;
using data_type = typename policy_type::data_type; using data_type = typename policy_type::data_type;
using size_type = typename policy_type::size_type;
using proxy_type = typename policy_type::proxy_type; using proxy_type = typename policy_type::proxy_type;
using const_proxy_type = typename policy_type::const_proxy_type; using const_proxy_type = typename policy_type::const_proxy_type;
@ -128,21 +128,21 @@ class BitsetBase {
} }
// Return the number of bits we're working with. // Return the number of bits we're working with.
inline size_type inline size_t
size() const { size() const {
return as_derived().size_impl(); return as_derived().size_impl();
} }
// Return the number of bytes which is needed to // Return the number of bytes which is needed to
// contain all our bits. // contain all our bits.
inline size_type inline size_t
size_in_bytes() const { size_in_bytes() const {
return policy_type::get_required_size_in_bytes(this->size()); return policy_type::get_required_size_in_bytes(this->size());
} }
// Return the number of elements which is needed to // Return the number of elements which is needed to
// contain all our bits. // contain all our bits.
inline size_type inline size_t
size_in_elements() const { size_in_elements() const {
return policy_type::get_required_size_in_elements(this->size()); return policy_type::get_required_size_in_elements(this->size());
} }
@ -155,19 +155,19 @@ class BitsetBase {
// //
inline proxy_type inline proxy_type
operator[](const size_type bit_idx) { operator[](const size_t bit_idx) {
range_checker::lt(bit_idx, this->size()); range_checker::lt(bit_idx, this->size());
const size_type idx_v = bit_idx + this->offset(); const size_t idx_v = bit_idx + this->offset();
return policy_type::get_proxy(this->data(), idx_v); return policy_type::get_proxy(this->data(), idx_v);
} }
// //
inline bool inline bool
operator[](const size_type bit_idx) const { operator[](const size_t bit_idx) const {
range_checker::lt(bit_idx, this->size()); range_checker::lt(bit_idx, this->size());
const size_type idx_v = bit_idx + this->offset(); const size_t idx_v = bit_idx + this->offset();
const auto proxy = policy_type::get_proxy(this->data(), idx_v); const auto proxy = policy_type::get_proxy(this->data(), idx_v);
return proxy.operator bool(); return proxy.operator bool();
} }
@ -180,10 +180,21 @@ class BitsetBase {
// Set a given bit to a given value. // Set a given bit to a given value.
inline void inline void
set(const size_type bit_idx, const bool value = true) { set(const size_t bit_idx, const bool value = true) {
this->operator[](bit_idx) = value; this->operator[](bit_idx) = value;
} }
// Set `size` consecutive bits, starting at `bit_idx_start`, to a given value.
// The affected bit indices are [bit_idx_start, bit_idx_start + size).
//
// bit_idx_start - index of the first bit to modify, relative to this bitset
// size          - the number of bits to modify
// value         - the value assigned to every bit of the range
inline void
set(const size_t bit_idx_start,
    const size_t size,
    const bool value = true) {
    // Validate the start first and then compare `size` against the room
    // that is left. Checking `bit_idx_start + size` directly may wrap
    // around for huge arguments and let an out-of-range request through.
    // NOTE(review): the subtraction relies on the first check rejecting
    // bit_idx_start > size() before the second one matters - confirm
    // against range_checker.
    range_checker::le(bit_idx_start, this->size());
    range_checker::le(size, this->size() - bit_idx_start);

    // this->offset() is added to turn the bitset-relative index into
    // the position passed to the policy.
    policy_type::op_fill(
        this->data(), this->offset() + bit_idx_start, size, value);
}
// Set all bits to false. // Set all bits to false.
inline void inline void
reset() { reset() {
@ -192,10 +203,16 @@ class BitsetBase {
// Set a given bit to false. // Set a given bit to false.
inline void inline void
reset(const size_type bit_idx) { reset(const size_t bit_idx) {
this->operator[](bit_idx) = false; this->operator[](bit_idx) = false;
} }
// Clear `size` consecutive bits starting at `bit_idx_start`, i.e. set
// the bits [bit_idx_start, bit_idx_start + size) to false.
// Forwards to the ranged set().
inline void
reset(const size_t bit_idx_start, const size_t size) {
    constexpr bool cleared_value = false;
    this->set(bit_idx_start, size, cleared_value);
}
// Return whether all bits are set to true. // Return whether all bits are set to true.
inline bool inline bool
all() const { all() const {
@ -217,7 +234,7 @@ class BitsetBase {
// Inplace and. // Inplace and.
template <typename I, bool R> template <typename I, bool R>
inline void inline void
inplace_and(const BitsetBase<PolicyT, I, R>& other, const size_type size) { inplace_and(const BitsetBase<PolicyT, I, R>& other, const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -225,6 +242,74 @@ class BitsetBase {
this->data(), other.data(), this->offset(), other.offset(), size); this->data(), other.data(), this->offset(), other.offset(), size);
} }
template <bool R>
inline void
inplace_and(const BitsetView<PolicyT, R>* const others,
const size_t n_others,
const size_t size) {
range_checker::le(size, this->size());
for (size_t i = 0; i < n_others; i++) {
range_checker::le(size, others[i].size());
}
// pick buffers
detail::MaybeVector<const data_type*> tmp_data(n_others);
detail::MaybeVector<size_t> tmp_offset(n_others);
for (size_t i = 0; i < n_others; i++) {
tmp_data[i] = others[i].data();
tmp_offset[i] = others[i].offset();
}
policy_type::op_and_multiple(this->data(),
tmp_data.data(),
this->offset(),
tmp_offset.data(),
n_others,
size);
}
// Inplace and with several bitset views in a single op,
// using the full size of this bitset.
template <bool R>
inline void
inplace_and(const BitsetView<PolicyT, R>* const others,
            const size_t n_others) {
    const size_t n_bits = this->size();
    this->inplace_and(others, n_others, n_bits);
}
template <typename ContainerT, bool R>
inline void
inplace_and(const Bitset<PolicyT, ContainerT, R>* const others,
const size_t n_others,
const size_t size) {
range_checker::le(size, this->size());
for (size_t i = 0; i < n_others; i++) {
range_checker::le(size, others[i].size());
}
// pick buffers
detail::MaybeVector<const data_type*> tmp_data(n_others);
detail::MaybeVector<size_t> tmp_offset(n_others);
for (size_t i = 0; i < n_others; i++) {
tmp_data[i] = others[i].data();
tmp_offset[i] = others[i].offset();
}
policy_type::op_and_multiple(this->data(),
tmp_data.data(),
this->offset(),
tmp_offset.data(),
n_others,
size);
}
// Inplace and with several bitsets in a single op,
// using the full size of this bitset.
template <typename ContainerT, bool R>
inline void
inplace_and(const Bitset<PolicyT, ContainerT, R>* const others,
            const size_t n_others) {
    const size_t n_bits = this->size();
    this->inplace_and(others, n_others, n_bits);
}
// Inplace and. A given bitset / bitset view is expected to have the same size. // Inplace and. A given bitset / bitset view is expected to have the same size.
template <typename I, bool R> template <typename I, bool R>
inline ImplT& inline ImplT&
@ -238,7 +323,7 @@ class BitsetBase {
// Inplace or. // Inplace or.
template <typename I, bool R> template <typename I, bool R>
inline void inline void
inplace_or(const BitsetBase<PolicyT, I, R>& other, const size_type size) { inplace_or(const BitsetBase<PolicyT, I, R>& other, const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -246,6 +331,74 @@ class BitsetBase {
this->data(), other.data(), this->offset(), other.offset(), size); this->data(), other.data(), this->offset(), other.offset(), size);
} }
template <bool R>
inline void
inplace_or(const BitsetView<PolicyT, R>* const others,
const size_t n_others,
const size_t size) {
range_checker::le(size, this->size());
for (size_t i = 0; i < n_others; i++) {
range_checker::le(size, others[i].size());
}
// pick buffers
detail::MaybeVector<const data_type*> tmp_data(n_others);
detail::MaybeVector<size_t> tmp_offset(n_others);
for (size_t i = 0; i < n_others; i++) {
tmp_data[i] = others[i].data();
tmp_offset[i] = others[i].offset();
}
policy_type::op_or_multiple(this->data(),
tmp_data.data(),
this->offset(),
tmp_offset.data(),
n_others,
size);
}
// Inplace or with several bitset views in a single op,
// using the full size of this bitset.
template <bool R>
inline void
inplace_or(const BitsetView<PolicyT, R>* const others,
           const size_t n_others) {
    const size_t n_bits = this->size();
    this->inplace_or(others, n_others, n_bits);
}
template <typename ContainerT, bool R>
inline void
inplace_or(const Bitset<PolicyT, ContainerT, R>* const others,
const size_t n_others,
const size_t size) {
range_checker::le(size, this->size());
for (size_t i = 0; i < n_others; i++) {
range_checker::le(size, others[i].size());
}
// pick buffers
detail::MaybeVector<const data_type*> tmp_data(n_others);
detail::MaybeVector<size_t> tmp_offset(n_others);
for (size_t i = 0; i < n_others; i++) {
tmp_data[i] = others[i].data();
tmp_offset[i] = others[i].offset();
}
policy_type::op_or_multiple(this->data(),
tmp_data.data(),
this->offset(),
tmp_offset.data(),
n_others,
size);
}
// Inplace or with several bitsets in a single op,
// using the full size of this bitset.
template <typename ContainerT, bool R>
inline void
inplace_or(const Bitset<PolicyT, ContainerT, R>* const others,
           const size_t n_others) {
    const size_t n_bits = this->size();
    this->inplace_or(others, n_others, n_bits);
}
// Inplace or. A given bitset / bitset view is expected to have the same size. // Inplace or. A given bitset / bitset view is expected to have the same size.
template <typename I, bool R> template <typename I, bool R>
inline ImplT& inline ImplT&
@ -264,13 +417,13 @@ class BitsetBase {
// //
inline BitsetView<PolicyT, IsRangeCheckEnabled> inline BitsetView<PolicyT, IsRangeCheckEnabled>
operator+(const size_type offset) { operator+(const size_t offset) {
return this->view(offset); return this->view(offset);
} }
// Create a view of a given size from the given position. // Create a view of a given size from the given position.
inline BitsetView<PolicyT, IsRangeCheckEnabled> inline BitsetView<PolicyT, IsRangeCheckEnabled>
view(const size_type offset, const size_type size) { view(const size_t offset, const size_t size) {
range_checker::le(offset, this->size()); range_checker::le(offset, this->size());
range_checker::le(offset + size, this->size()); range_checker::le(offset + size, this->size());
@ -280,7 +433,7 @@ class BitsetBase {
// Create a const view of a given size from the given position. // Create a const view of a given size from the given position.
inline BitsetView<PolicyT, IsRangeCheckEnabled> inline BitsetView<PolicyT, IsRangeCheckEnabled>
view(const size_type offset, const size_type size) const { view(const size_t offset, const size_t size) const {
range_checker::le(offset, this->size()); range_checker::le(offset, this->size());
range_checker::le(offset + size, this->size()); range_checker::le(offset + size, this->size());
@ -292,7 +445,7 @@ class BitsetBase {
// Create a view from the given position, which uses all available size. // Create a view from the given position, which uses all available size.
inline BitsetView<PolicyT, IsRangeCheckEnabled> inline BitsetView<PolicyT, IsRangeCheckEnabled>
view(const size_type offset) { view(const size_t offset) {
range_checker::le(offset, this->size()); range_checker::le(offset, this->size());
return BitsetView<PolicyT, IsRangeCheckEnabled>( return BitsetView<PolicyT, IsRangeCheckEnabled>(
@ -301,7 +454,7 @@ class BitsetBase {
// Create a const view from the given position, which uses all available size. // Create a const view from the given position, which uses all available size.
inline const BitsetView<PolicyT, IsRangeCheckEnabled> inline const BitsetView<PolicyT, IsRangeCheckEnabled>
view(const size_type offset) const { view(const size_t offset) const {
range_checker::le(offset, this->size()); range_checker::le(offset, this->size());
return BitsetView<PolicyT, IsRangeCheckEnabled>( return BitsetView<PolicyT, IsRangeCheckEnabled>(
@ -323,7 +476,7 @@ class BitsetBase {
} }
// Return the number of bits which are set to true. // Return the number of bits which are set to true.
inline size_type inline size_t
count() const { count() const {
return policy_type::op_count( return policy_type::op_count(
this->data(), this->offset(), this->size()); this->data(), this->offset(), this->size());
@ -354,7 +507,7 @@ class BitsetBase {
// Inplace xor. // Inplace xor.
template <typename I, bool R> template <typename I, bool R>
inline void inline void
inplace_xor(const BitsetBase<PolicyT, I, R>& other, const size_type size) { inplace_xor(const BitsetBase<PolicyT, I, R>& other, const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -375,7 +528,7 @@ class BitsetBase {
// Inplace sub. // Inplace sub.
template <typename I, bool R> template <typename I, bool R>
inline void inline void
inplace_sub(const BitsetBase<PolicyT, I, R>& other, const size_type size) { inplace_sub(const BitsetBase<PolicyT, I, R>& other, const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -394,16 +547,16 @@ class BitsetBase {
} }
// Find the index of the first bit set to true. // Find the index of the first bit set to true.
inline std::optional<size_type> inline std::optional<size_t>
find_first() const { find_first() const {
return policy_type::op_find( return policy_type::op_find(
this->data(), this->offset(), this->size(), 0); this->data(), this->offset(), this->size(), 0);
} }
// Find the index of the first bit set to true, starting from a given bit index. // Find the index of the first bit set to true, starting from a given bit index.
inline std::optional<size_type> inline std::optional<size_t>
find_next(const size_type starting_bit_idx) const { find_next(const size_t starting_bit_idx) const {
const size_type size_v = this->size(); const size_t size_v = this->size();
if (starting_bit_idx + 1 >= size_v) { if (starting_bit_idx + 1 >= size_v) {
return std::nullopt; return std::nullopt;
} }
@ -414,7 +567,7 @@ class BitsetBase {
// Read multiple bits starting from a given bit index. // Read multiple bits starting from a given bit index.
inline data_type inline data_type
read(const size_type starting_bit_idx, const size_type nbits) { read(const size_t starting_bit_idx, const size_t nbits) {
range_checker::le(nbits, sizeof(data_type)); range_checker::le(nbits, sizeof(data_type));
return policy_type::op_read( return policy_type::op_read(
@ -423,9 +576,9 @@ class BitsetBase {
// Write multiple bits starting from a given bit index. // Write multiple bits starting from a given bit index.
inline void inline void
write(const size_type starting_bit_idx, write(const size_t starting_bit_idx,
const data_type value, const data_type value,
const size_type nbits) { const size_t nbits) {
range_checker::le(nbits, sizeof(data_type)); range_checker::le(nbits, sizeof(data_type));
policy_type::op_write( policy_type::op_write(
@ -437,7 +590,7 @@ class BitsetBase {
void void
inplace_compare_column(const T* const __restrict t, inplace_compare_column(const T* const __restrict t,
const U* const __restrict u, const U* const __restrict u,
const size_type size, const size_t size,
CompareOpType op) { CompareOpType op) {
if (op == CompareOpType::EQ) { if (op == CompareOpType::EQ) {
this->inplace_compare_column<T, U, CompareOpType::EQ>(t, u, size); this->inplace_compare_column<T, U, CompareOpType::EQ>(t, u, size);
@ -460,7 +613,7 @@ class BitsetBase {
void void
inplace_compare_column(const T* const __restrict t, inplace_compare_column(const T* const __restrict t,
const U* const __restrict u, const U* const __restrict u,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
policy_type::template op_compare_column<T, U, Op>( policy_type::template op_compare_column<T, U, Op>(
@ -471,7 +624,7 @@ class BitsetBase {
template <typename T> template <typename T>
void void
inplace_compare_val(const T* const __restrict t, inplace_compare_val(const T* const __restrict t,
const size_type size, const size_t size,
const T& value, const T& value,
CompareOpType op) { CompareOpType op) {
if (op == CompareOpType::EQ) { if (op == CompareOpType::EQ) {
@ -494,7 +647,7 @@ class BitsetBase {
template <typename T, CompareOpType Op> template <typename T, CompareOpType Op>
void void
inplace_compare_val(const T* const __restrict t, inplace_compare_val(const T* const __restrict t,
const size_type size, const size_t size,
const T& value) { const T& value) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
@ -508,7 +661,7 @@ class BitsetBase {
inplace_within_range_column(const T* const __restrict lower, inplace_within_range_column(const T* const __restrict lower,
const T* const __restrict upper, const T* const __restrict upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size, const size_t size,
const RangeType op) { const RangeType op) {
if (op == RangeType::IncInc) { if (op == RangeType::IncInc) {
this->inplace_within_range_column<T, RangeType::IncInc>( this->inplace_within_range_column<T, RangeType::IncInc>(
@ -532,7 +685,7 @@ class BitsetBase {
inplace_within_range_column(const T* const __restrict lower, inplace_within_range_column(const T* const __restrict lower,
const T* const __restrict upper, const T* const __restrict upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
policy_type::template op_within_range_column<T, Op>( policy_type::template op_within_range_column<T, Op>(
@ -545,7 +698,7 @@ class BitsetBase {
inplace_within_range_val(const T& lower, inplace_within_range_val(const T& lower,
const T& upper, const T& upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size, const size_t size,
const RangeType op) { const RangeType op) {
if (op == RangeType::IncInc) { if (op == RangeType::IncInc) {
this->inplace_within_range_val<T, RangeType::IncInc>( this->inplace_within_range_val<T, RangeType::IncInc>(
@ -569,7 +722,7 @@ class BitsetBase {
inplace_within_range_val(const T& lower, inplace_within_range_val(const T& lower,
const T& upper, const T& upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
policy_type::template op_within_range_val<T, Op>( policy_type::template op_within_range_val<T, Op>(
@ -582,7 +735,7 @@ class BitsetBase {
inplace_arith_compare(const T* const __restrict src, inplace_arith_compare(const T* const __restrict src,
const ArithHighPrecisionType<T>& right_operand, const ArithHighPrecisionType<T>& right_operand,
const ArithHighPrecisionType<T>& value, const ArithHighPrecisionType<T>& value,
const size_type size, const size_t size,
const ArithOpType a_op, const ArithOpType a_op,
const CompareOpType cmp_op) { const CompareOpType cmp_op) {
if (a_op == ArithOpType::Add) { if (a_op == ArithOpType::Add) {
@ -765,7 +918,7 @@ class BitsetBase {
inplace_arith_compare(const T* const __restrict src, inplace_arith_compare(const T* const __restrict src,
const ArithHighPrecisionType<T>& right_operand, const ArithHighPrecisionType<T>& right_operand,
const ArithHighPrecisionType<T>& value, const ArithHighPrecisionType<T>& value,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
policy_type::template op_arith_compare<T, AOp, CmpOp>( policy_type::template op_arith_compare<T, AOp, CmpOp>(
@ -775,9 +928,9 @@ class BitsetBase {
// //
// Inplace and. Also, counts the number of active bits. // Inplace and. Also, counts the number of active bits.
template <typename I, bool R> template <typename I, bool R>
inline size_type inline size_t
inplace_and_with_count(const BitsetBase<PolicyT, I, R>& other, inplace_and_with_count(const BitsetBase<PolicyT, I, R>& other,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -787,9 +940,9 @@ class BitsetBase {
// Inplace or. Also, counts the number of inactive bits. // Inplace or. Also, counts the number of inactive bits.
template <typename I, bool R> template <typename I, bool R>
inline size_type inline size_t
inplace_or_with_count(const BitsetBase<PolicyT, I, R>& other, inplace_or_with_count(const BitsetBase<PolicyT, I, R>& other,
const size_type size) { const size_t size) {
range_checker::le(size, this->size()); range_checker::le(size, this->size());
range_checker::le(size, other.size()); range_checker::le(size, other.size());
@ -798,7 +951,7 @@ class BitsetBase {
} }
// Return the starting bit offset in our container. // Return the starting bit offset in our container.
inline size_type inline size_t
offset() const { offset() const {
return as_derived().offset_impl(); return as_derived().offset_impl();
} }
@ -829,7 +982,6 @@ class BitsetView : public BitsetBase<PolicyT,
public: public:
using policy_type = PolicyT; using policy_type = PolicyT;
using data_type = typename policy_type::data_type; using data_type = typename policy_type::data_type;
using size_type = typename policy_type::size_type;
using proxy_type = typename policy_type::proxy_type; using proxy_type = typename policy_type::proxy_type;
using const_proxy_type = typename policy_type::const_proxy_type; using const_proxy_type = typename policy_type::const_proxy_type;
@ -849,11 +1001,11 @@ class BitsetView : public BitsetBase<PolicyT,
: Data{bitset.data()}, Size{bitset.size()}, Offset{bitset.offset()} { : Data{bitset.data()}, Size{bitset.size()}, Offset{bitset.offset()} {
} }
BitsetView(void* data, const size_type size) BitsetView(void* data, const size_t size)
: Data{reinterpret_cast<data_type*>(data)}, Size{size}, Offset{0} { : Data{reinterpret_cast<data_type*>(data)}, Size{size}, Offset{0} {
} }
BitsetView(void* data, const size_type offset, const size_type size) BitsetView(void* data, const size_t offset, const size_t size)
: Data{reinterpret_cast<data_type*>(data)}, Size{size}, Offset{offset} { : Data{reinterpret_cast<data_type*>(data)}, Size{size}, Offset{offset} {
} }
@ -861,9 +1013,9 @@ class BitsetView : public BitsetBase<PolicyT,
// the referenced bits are [Offset, Offset + Size) // the referenced bits are [Offset, Offset + Size)
data_type* Data = nullptr; data_type* Data = nullptr;
// measured in bits // measured in bits
size_type Size = 0; size_t Size = 0;
// measured in bits // measured in bits
size_type Offset = 0; size_t Offset = 0;
inline data_type* inline data_type*
data_impl() { data_impl() {
@ -873,11 +1025,11 @@ class BitsetView : public BitsetBase<PolicyT,
data_impl() const { data_impl() const {
return Data; return Data;
} }
inline size_type inline size_t
size_impl() const { size_impl() const {
return Size; return Size;
} }
inline size_type inline size_t
offset_impl() const { offset_impl() const {
return Offset; return Offset;
} }
@ -896,10 +1048,11 @@ class Bitset
public: public:
using policy_type = PolicyT; using policy_type = PolicyT;
using data_type = typename policy_type::data_type; using data_type = typename policy_type::data_type;
using size_type = typename policy_type::size_type;
using proxy_type = typename policy_type::proxy_type; using proxy_type = typename policy_type::proxy_type;
using const_proxy_type = typename policy_type::const_proxy_type; using const_proxy_type = typename policy_type::const_proxy_type;
using view_type = BitsetView<PolicyT, IsRangeCheckEnabled>;
// This is the container type. // This is the container type.
using container_type = ContainerT; using container_type = ContainerT;
// This is how the data is stored. For example, we may operate using // This is how the data is stored. For example, we may operate using
@ -914,11 +1067,11 @@ class Bitset
Bitset() { Bitset() {
} }
// Allocate the given number of bits. // Allocate the given number of bits.
Bitset(const size_type size) Bitset(const size_t size)
: Data(get_required_size_in_container_elements(size)), Size{size} { : Data(get_required_size_in_container_elements(size)), Size{size} {
} }
// Allocate the given number of bits, initialize with a given value. // Allocate the given number of bits, initialize with a given value.
Bitset(const size_type size, const bool init) Bitset(const size_t size, const bool init)
: Data(get_required_size_in_container_elements(size), : Data(get_required_size_in_container_elements(size),
init ? data_type(-1) : 0), init ? data_type(-1) : 0),
Size{size} { Size{size} {
@ -964,8 +1117,8 @@ class Bitset
// Resize. // Resize.
void void
resize(const size_type new_size) { resize(const size_t new_size) {
const size_type new_size_in_container_elements = const size_t new_size_in_container_elements =
get_required_size_in_container_elements(new_size); get_required_size_in_container_elements(new_size);
Data.resize(new_size_in_container_elements); Data.resize(new_size_in_container_elements);
Size = new_size; Size = new_size;
@ -973,8 +1126,8 @@ class Bitset
// Resize and initialize new bits with a given value if grown. // Resize and initialize new bits with a given value if grown.
void void
resize(const size_type new_size, const bool init) { resize(const size_t new_size, const bool init) {
const size_type old_size = this->size(); const size_t old_size = this->size();
this->resize(new_size); this->resize(new_size);
if (new_size > old_size) { if (new_size > old_size) {
@ -989,11 +1142,11 @@ class Bitset
template <typename I, bool R> template <typename I, bool R>
void void
append(const BitsetBase<PolicyT, I, R>& other, append(const BitsetBase<PolicyT, I, R>& other,
const size_type starting_bit_idx, const size_t starting_bit_idx,
const size_type count) { const size_t count) {
range_checker::le(starting_bit_idx, other.size()); range_checker::le(starting_bit_idx, other.size());
const size_type old_size = this->size(); const size_t old_size = this->size();
this->resize(this->size() + count); this->resize(this->size() + count);
policy_type::op_copy(other.data(), policy_type::op_copy(other.data(),
@ -1020,8 +1173,8 @@ class Bitset
// Reserve // Reserve
inline void inline void
reserve(const size_type capacity) { reserve(const size_t capacity) {
const size_type capacity_in_container_elements = const size_t capacity_in_container_elements =
get_required_size_in_container_elements(capacity); get_required_size_in_container_elements(capacity);
Data.reserve(capacity_in_container_elements); Data.reserve(capacity_in_container_elements);
} }
@ -1048,7 +1201,7 @@ class Bitset
// the container // the container
container_type Data; container_type Data;
// the actual number of bits // the actual number of bits
size_type Size = 0; size_t Size = 0;
inline data_type* inline data_type*
data_impl() { data_impl() {
@ -1058,19 +1211,19 @@ class Bitset
data_impl() const { data_impl() const {
return reinterpret_cast<const data_type*>(Data.data()); return reinterpret_cast<const data_type*>(Data.data());
} }
inline size_type inline size_t
size_impl() const { size_impl() const {
return Size; return Size;
} }
inline size_type inline size_t
offset_impl() const { offset_impl() const {
return 0; return 0;
} }
// //
static inline size_type static inline size_t
get_required_size_in_container_elements(const size_t size) { get_required_size_in_container_elements(const size_t size) {
const size_type size_in_bytes = const size_t size_in_bytes =
policy_type::get_required_size_in_bytes(size); policy_type::get_required_size_in_bytes(size);
return (size_in_bytes + sizeof(container_data_type) - 1) / return (size_in_bytes + sizeof(container_data_type) - 1) /
sizeof(container_data_type); sizeof(container_data_type);

View File

@ -27,6 +27,19 @@ namespace bitset {
// this option is only somewhat supported // this option is only somewhat supported
// #define BITSET_HEADER_ONLY // #define BITSET_HEADER_ONLY
// `always inline` hint.
// It is introduced to deal with clang's tendency to reuse
// code that has already been generated. Such behavior is
// undesired when different machine code has to be generated
// for multiple platforms from a single template.
// `always inline` is applied to PolicyT methods. This is fine,
// because they are not used directly and are wrapped
// in BitsetBase methods. So, a compiler may decide whether
// to really inline them, but it is forced to
// generate specialized code for every hardware platform.
// todo: MSVC has its own way to define `always inline`.
#define BITSET_ALWAYS_INLINE __attribute__((always_inline))
// a supporting utility // a supporting utility
template <class> template <class>
inline constexpr bool always_false_v = false; inline constexpr bool always_false_v = false;

View File

@ -32,55 +32,53 @@ namespace detail {
template <typename ElementT> template <typename ElementT>
struct BitWiseBitsetPolicy { struct BitWiseBitsetPolicy {
using data_type = ElementT; using data_type = ElementT;
constexpr static auto data_bits = sizeof(data_type) * 8; constexpr static size_t data_bits = sizeof(data_type) * 8;
using size_type = size_t;
using self_type = BitWiseBitsetPolicy<ElementT>; using self_type = BitWiseBitsetPolicy<ElementT>;
using proxy_type = Proxy<self_type>; using proxy_type = Proxy<self_type>;
using const_proxy_type = ConstProxy<self_type>; using const_proxy_type = ConstProxy<self_type>;
static inline size_type static inline size_t
get_element(const size_t idx) { get_element(const size_t idx) {
return idx / data_bits; return idx / data_bits;
} }
static inline size_type static inline size_t
get_shift(const size_t idx) { get_shift(const size_t idx) {
return idx % data_bits; return idx % data_bits;
} }
static inline size_type static inline size_t
get_required_size_in_elements(const size_t size) { get_required_size_in_elements(const size_t size) {
return (size + data_bits - 1) / data_bits; return (size + data_bits - 1) / data_bits;
} }
static inline size_type static inline size_t
get_required_size_in_bytes(const size_t size) { get_required_size_in_bytes(const size_t size) {
return get_required_size_in_elements(size) * sizeof(data_type); return get_required_size_in_elements(size) * sizeof(data_type);
} }
static inline proxy_type static inline proxy_type
get_proxy(data_type* const __restrict data, const size_type idx) { get_proxy(data_type* const __restrict data, const size_t idx) {
data_type& element = data[get_element(idx)]; data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return proxy_type{element, shift}; return proxy_type{element, shift};
} }
static inline const_proxy_type static inline const_proxy_type
get_proxy(const data_type* const __restrict data, const size_type idx) { get_proxy(const data_type* const __restrict data, const size_t idx) {
const data_type& element = data[get_element(idx)]; const data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return const_proxy_type{element, shift}; return const_proxy_type{element, shift};
} }
static inline data_type static inline data_type
op_read(const data_type* const data, op_read(const data_type* const data,
const size_type start, const size_t start,
const size_type nbits) { const size_t nbits) {
data_type value = 0; data_type value = 0;
for (size_type i = 0; i < nbits; i++) { for (size_t i = 0; i < nbits; i++) {
const auto proxy = get_proxy(data, start + i); const auto proxy = get_proxy(data, start + i);
value += proxy ? (data_type(1) << i) : 0; value += proxy ? (data_type(1) << i) : 0;
} }
@ -90,10 +88,10 @@ struct BitWiseBitsetPolicy {
static void static void
op_write(data_type* const data, op_write(data_type* const data,
const size_type start, const size_t start,
const size_type nbits, const size_t nbits,
const data_type value) { const data_type value) {
for (size_type i = 0; i < nbits; i++) { for (size_t i = 0; i < nbits; i++) {
auto proxy = get_proxy(data, start + i); auto proxy = get_proxy(data, start + i);
data_type mask = data_type(1) << i; data_type mask = data_type(1) << i;
if ((value & mask) == mask) { if ((value & mask) == mask) {
@ -105,10 +103,8 @@ struct BitWiseBitsetPolicy {
} }
static inline void static inline void
op_flip(data_type* const data, op_flip(data_type* const data, const size_t start, const size_t size) {
const size_type start, for (size_t i = 0; i < size; i++) {
const size_type size) {
for (size_type i = 0; i < size; i++) {
auto proxy = get_proxy(data, start + i); auto proxy = get_proxy(data, start + i);
proxy.flip(); proxy.flip();
} }
@ -122,7 +118,7 @@ struct BitWiseBitsetPolicy {
const size_t size) { const size_t size) {
// todo: check if intersect // todo: check if intersect
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
auto proxy_right = get_proxy(right, start_right + i); auto proxy_right = get_proxy(right, start_right + i);
@ -130,6 +126,27 @@ struct BitWiseBitsetPolicy {
} }
} }
// Multi-operand AND, evaluated one bit at a time.
//
// For each of the `size` bit positions, the bit of `left` at
// start_left + position is combined with bit start_rights[r] + position of
// every rights[r] (r < n_rights) and the result is written back to `left`.
// With n_rights == 0 every bit of `left` is rewritten with its own value.
static inline void
op_and_multiple(data_type* const left,
                const data_type* const* const rights,
                const size_t start_left,
                const size_t* const __restrict start_rights,
                const size_t n_rights,
                const size_t size) {
    for (size_t bit = 0; bit != size; ++bit) {
        auto dst = get_proxy(left, start_left + bit);

        // Every operand is read for every bit; there is no early exit
        // once the accumulator has dropped to false.
        bool acc = dst;
        for (size_t r = 0; r != n_rights; ++r) {
            const auto src = get_proxy(rights[r], start_rights[r] + bit);
            acc = acc & static_cast<bool>(src);
        }

        dst = acc;
    }
}
static inline void static inline void
op_or(data_type* const left, op_or(data_type* const left,
const data_type* const right, const data_type* const right,
@ -138,7 +155,7 @@ struct BitWiseBitsetPolicy {
const size_t size) { const size_t size) {
// todo: check if intersect // todo: check if intersect
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
auto proxy_right = get_proxy(right, start_right + i); auto proxy_right = get_proxy(right, start_right + i);
@ -147,26 +164,43 @@ struct BitWiseBitsetPolicy {
} }
// Multi-operand OR: for each of the `size` bit positions, reads the bit of
// `left` at start_left + i, ORs in bit start_rights[j] + i of every rights[j]
// (j < n_rights), and stores the result back into `left`.
static inline void static inline void
op_set(data_type* const data, const size_type start, const size_type size) { op_or_multiple(data_type* const left,
for (size_type i = 0; i < size; i++) { const data_type* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i);
bool value = proxy_left;
for (size_t j = 0; j < n_rights; j++) {
auto proxy_right = get_proxy(rights[j], start_rights[j] + i);
value |= proxy_right;
}
proxy_left = value;
}
}
// Sets the `size` bits of `data` that begin at bit `start` to true,
// performing one proxy write per bit.
static inline void
op_set(data_type* const data, const size_t start, const size_t size) {
for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = true; get_proxy(data, start + i) = true;
} }
} }
// Clears the `size` bits of `data` that begin at bit `start` (writes false),
// performing one proxy write per bit.
static inline void static inline void
op_reset(data_type* const data, op_reset(data_type* const data, const size_t start, const size_t size) {
const size_type start, for (size_t i = 0; i < size; i++) {
const size_type size) {
for (size_type i = 0; i < size; i++) {
get_proxy(data, start + i) = false; get_proxy(data, start + i) = false;
} }
} }
static inline bool static inline bool
op_all(const data_type* const data, op_all(const data_type* const data, const size_t start, const size_t size) {
const size_type start, for (size_t i = 0; i < size; i++) {
const size_type size) {
for (size_type i = 0; i < size; i++) {
if (!get_proxy(data, start + i)) { if (!get_proxy(data, start + i)) {
return false; return false;
} }
@ -177,9 +211,9 @@ struct BitWiseBitsetPolicy {
static inline bool static inline bool
op_none(const data_type* const data, op_none(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (get_proxy(data, start + i)) { if (get_proxy(data, start + i)) {
return false; return false;
} }
@ -190,11 +224,11 @@ struct BitWiseBitsetPolicy {
static void static void
op_copy(const data_type* const src, op_copy(const data_type* const src,
const size_type start_src, const size_t start_src,
data_type* const dst, data_type* const dst,
const size_type start_dst, const size_t start_dst,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
const auto src_p = get_proxy(src, start_src + i); const auto src_p = get_proxy(src, start_src + i);
auto dst_p = get_proxy(dst, start_dst + i); auto dst_p = get_proxy(dst, start_dst + i);
dst_p = src_p.operator bool(); dst_p = src_p.operator bool();
@ -203,22 +237,22 @@ struct BitWiseBitsetPolicy {
// Writes `value` into each of the `size` bits of `dst` starting at bit
// `start_dst`, one proxy assignment per bit.
static void static void
op_fill(data_type* const dst, op_fill(data_type* const dst,
const size_type start_dst, const size_t start_dst,
const size_type size, const size_t size,
const bool value) { const bool value) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto dst_p = get_proxy(dst, start_dst + i); auto dst_p = get_proxy(dst, start_dst + i);
dst_p = value; dst_p = value;
} }
} }
static inline size_type static inline size_t
op_count(const data_type* const data, op_count(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
size_type count = 0; size_t count = 0;
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy = get_proxy(data, start + i); auto proxy = get_proxy(data, start + i);
count += (proxy) ? 1 : 0; count += (proxy) ? 1 : 0;
} }
@ -232,7 +266,7 @@ struct BitWiseBitsetPolicy {
const size_t start_left, const size_t start_left,
const size_t start_right, const size_t start_right,
const size_t size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
const auto proxy_left = get_proxy(left, start_left + i); const auto proxy_left = get_proxy(left, start_left + i);
const auto proxy_right = get_proxy(right, start_right + i); const auto proxy_right = get_proxy(right, start_right + i);
@ -252,7 +286,7 @@ struct BitWiseBitsetPolicy {
const size_t size) { const size_t size) {
// todo: check if intersect // todo: check if intersect
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
const auto proxy_right = get_proxy(right, start_right + i); const auto proxy_right = get_proxy(right, start_right + i);
@ -268,7 +302,7 @@ struct BitWiseBitsetPolicy {
const size_t size) { const size_t size) {
// todo: check if intersect // todo: check if intersect
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
const auto proxy_right = get_proxy(right, start_right + i); const auto proxy_right = get_proxy(right, start_right + i);
@ -277,12 +311,12 @@ struct BitWiseBitsetPolicy {
} }
// //
static inline std::optional<size_type> static inline std::optional<size_t>
op_find(const data_type* const data, op_find(const data_type* const data,
const size_type start, const size_t start,
const size_type size, const size_t size,
const size_type starting_idx) { const size_t starting_idx) {
for (size_type i = starting_idx; i < size; i++) { for (size_t i = starting_idx; i < size; i++) {
const auto proxy = get_proxy(data, start + i); const auto proxy = get_proxy(data, start + i);
if (proxy) { if (proxy) {
return i; return i;
@ -296,11 +330,11 @@ struct BitWiseBitsetPolicy {
template <typename T, typename U, CompareOpType Op> template <typename T, typename U, CompareOpType Op>
static inline void static inline void
op_compare_column(data_type* const __restrict data, op_compare_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const U* const __restrict u, const U* const __restrict u,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = get_proxy(data, start + i) =
CompareOperator<Op>::compare(t[i], u[i]); CompareOperator<Op>::compare(t[i], u[i]);
} }
@ -310,11 +344,11 @@ struct BitWiseBitsetPolicy {
template <typename T, CompareOpType Op> template <typename T, CompareOpType Op>
static inline void static inline void
op_compare_val(data_type* const __restrict data, op_compare_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const size_type size, const size_t size,
const T& value) { const T& value) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = get_proxy(data, start + i) =
CompareOperator<Op>::compare(t[i], value); CompareOperator<Op>::compare(t[i], value);
} }
@ -323,12 +357,12 @@ struct BitWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_column(data_type* const __restrict data, op_within_range_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict lower, const T* const __restrict lower,
const T* const __restrict upper, const T* const __restrict upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = get_proxy(data, start + i) =
RangeOperator<Op>::within_range(lower[i], upper[i], values[i]); RangeOperator<Op>::within_range(lower[i], upper[i], values[i]);
} }
@ -338,12 +372,12 @@ struct BitWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_val(data_type* const __restrict data, op_within_range_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T& lower, const T& lower,
const T& upper, const T& upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = get_proxy(data, start + i) =
RangeOperator<Op>::within_range(lower, upper, values[i]); RangeOperator<Op>::within_range(lower, upper, values[i]);
} }
@ -353,12 +387,12 @@ struct BitWiseBitsetPolicy {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static inline void static inline void
op_arith_compare(data_type* const __restrict data, op_arith_compare(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict src, const T* const __restrict src,
const ArithHighPrecisionType<T>& right_operand, const ArithHighPrecisionType<T>& right_operand,
const ArithHighPrecisionType<T>& value, const ArithHighPrecisionType<T>& value,
const size_type size) { const size_t size) {
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
get_proxy(data, start + i) = get_proxy(data, start + i) =
ArithCompareOperator<AOp, CmpOp>::compare( ArithCompareOperator<AOp, CmpOp>::compare(
src[i], right_operand, value); src[i], right_operand, value);
@ -375,7 +409,7 @@ struct BitWiseBitsetPolicy {
// todo: check if intersect // todo: check if intersect
size_t active = 0; size_t active = 0;
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
auto proxy_right = get_proxy(right, start_right + i); auto proxy_right = get_proxy(right, start_right + i);
@ -397,7 +431,7 @@ struct BitWiseBitsetPolicy {
// todo: check if intersect // todo: check if intersect
size_t inactive = 0; size_t inactive = 0;
for (size_type i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto proxy_left = get_proxy(left, start_left + i); auto proxy_left = get_proxy(left, start_left + i);
auto proxy_right = get_proxy(right, start_right + i); auto proxy_right = get_proxy(right, start_right + i);

View File

@ -32,53 +32,49 @@ namespace detail {
template <typename ElementT, typename VectorizedT> template <typename ElementT, typename VectorizedT>
struct VectorizedElementWiseBitsetPolicy { struct VectorizedElementWiseBitsetPolicy {
using data_type = ElementT; using data_type = ElementT;
constexpr static auto data_bits = sizeof(data_type) * 8; constexpr static size_t data_bits = sizeof(data_type) * 8;
using size_type = size_t;
using self_type = VectorizedElementWiseBitsetPolicy<ElementT, VectorizedT>; using self_type = VectorizedElementWiseBitsetPolicy<ElementT, VectorizedT>;
using proxy_type = Proxy<self_type>; using proxy_type = Proxy<self_type>;
using const_proxy_type = ConstProxy<self_type>; using const_proxy_type = ConstProxy<self_type>;
// Index of the data_type element that contains bit `idx` (idx / data_bits).
static inline size_type static inline size_t
get_element(const size_t idx) { get_element(const size_t idx) {
return idx / data_bits; return idx / data_bits;
} }
// Position of bit `idx` inside its element (idx % data_bits).
static inline size_type static inline size_t
get_shift(const size_t idx) { get_shift(const size_t idx) {
return idx % data_bits; return idx % data_bits;
} }
// Number of data_type elements needed to store `size` bits, rounded up so a
// partially used trailing element is counted.
static inline size_type static inline size_t
get_required_size_in_elements(const size_t size) { get_required_size_in_elements(const size_t size) {
return (size + data_bits - 1) / data_bits; return (size + data_bits - 1) / data_bits;
} }
// Storage size in bytes for `size` bits: required element count times
// sizeof(data_type).
static inline size_type static inline size_t
get_required_size_in_bytes(const size_t size) { get_required_size_in_bytes(const size_t size) {
return get_required_size_in_elements(size) * sizeof(data_type); return get_required_size_in_elements(size) * sizeof(data_type);
} }
// Returns a writable proxy for bit `idx`: the element is data[idx / data_bits]
// (via get_element) and the bit position is idx % data_bits (via get_shift).
static inline proxy_type static inline proxy_type
get_proxy(data_type* const __restrict data, const size_type idx) { get_proxy(data_type* const __restrict data, const size_t idx) {
data_type& element = data[get_element(idx)]; data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return proxy_type{element, shift}; return proxy_type{element, shift};
} }
// Read-only overload of get_proxy: same element/shift computation for bit
// `idx`, returning a const_proxy_type over a const element reference.
static inline const_proxy_type static inline const_proxy_type
get_proxy(const data_type* const __restrict data, const size_type idx) { get_proxy(const data_type* const __restrict data, const size_t idx) {
const data_type& element = data[get_element(idx)]; const data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return const_proxy_type{element, shift}; return const_proxy_type{element, shift};
} }
// Delegates directly to ElementWiseBitsetPolicy<ElementT>::op_flip with the
// same arguments; VectorizedT is not consulted here.
static inline void static inline void
op_flip(data_type* const data, op_flip(data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
ElementWiseBitsetPolicy<ElementT>::op_flip(data, start, size); ElementWiseBitsetPolicy<ElementT>::op_flip(data, start, size);
} }
@ -88,8 +84,25 @@ struct VectorizedElementWiseBitsetPolicy {
const size_t start_left, const size_t start_left,
const size_t start_right, const size_t start_right,
const size_t size) { const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_and( if (!VectorizedT::template forward_op_and<ElementT>(
left, right, start_left, start_right, size); left, right, start_left, start_right, size)) {
ElementWiseBitsetPolicy<ElementT>::op_and(
left, right, start_left, start_right, size);
}
}
// Multi-operand AND with a vectorized fast path: calls
// VectorizedT::forward_op_and_multiple<ElementT> first and, when that returns
// false, runs ElementWiseBitsetPolicy<ElementT>::op_and_multiple with the same
// arguments.
// NOTE(review): a false return presumably means the vectorized backend did
// not handle the request -- confirm against the VectorizedT implementation.
static inline void
op_and_multiple(data_type* const left,
const data_type* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
if (!VectorizedT::template forward_op_and_multiple<ElementT>(
left, rights, start_left, start_rights, n_rights, size)) {
ElementWiseBitsetPolicy<ElementT>::op_and_multiple(
left, rights, start_left, start_rights, n_rights, size);
}
} }
static inline void static inline void
@ -98,59 +111,72 @@ struct VectorizedElementWiseBitsetPolicy {
const size_t start_left, const size_t start_left,
const size_t start_right, const size_t start_right,
const size_t size) { const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_or( if (!VectorizedT::template forward_op_or<ElementT>(
left, right, start_left, start_right, size); left, right, start_left, start_right, size)) {
ElementWiseBitsetPolicy<ElementT>::op_or(
left, right, start_left, start_right, size);
}
} }
// Multi-operand OR with a vectorized fast path: calls
// VectorizedT::forward_op_or_multiple<ElementT> first and, when that returns
// false, runs ElementWiseBitsetPolicy<ElementT>::op_or_multiple with the same
// arguments.
// NOTE(review): a false return presumably means the vectorized backend did
// not handle the request -- confirm against the VectorizedT implementation.
static inline void static inline void
op_set(data_type* const data, const size_type start, const size_type size) { op_or_multiple(data_type* const left,
const data_type* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
if (!VectorizedT::template forward_op_or_multiple<ElementT>(
left, rights, start_left, start_rights, n_rights, size)) {
ElementWiseBitsetPolicy<ElementT>::op_or_multiple(
left, rights, start_left, start_rights, n_rights, size);
}
}
// Delegates directly to ElementWiseBitsetPolicy<ElementT>::op_set with the
// same arguments; VectorizedT is not consulted here.
static inline void
op_set(data_type* const data, const size_t start, const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_set(data, start, size); ElementWiseBitsetPolicy<ElementT>::op_set(data, start, size);
} }
// Delegates directly to ElementWiseBitsetPolicy<ElementT>::op_reset with the
// same arguments; VectorizedT is not consulted here.
static inline void static inline void
op_reset(data_type* const data, op_reset(data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
ElementWiseBitsetPolicy<ElementT>::op_reset(data, start, size); ElementWiseBitsetPolicy<ElementT>::op_reset(data, start, size);
} }
// Returns whatever ElementWiseBitsetPolicy<ElementT>::op_all returns for the
// same (data, start, size); VectorizedT is not consulted here.
static inline bool static inline bool
op_all(const data_type* const data, op_all(const data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
return ElementWiseBitsetPolicy<ElementT>::op_all(data, start, size); return ElementWiseBitsetPolicy<ElementT>::op_all(data, start, size);
} }
// Returns whatever ElementWiseBitsetPolicy<ElementT>::op_none returns for the
// same (data, start, size); VectorizedT is not consulted here.
static inline bool static inline bool
op_none(const data_type* const data, op_none(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
return ElementWiseBitsetPolicy<ElementT>::op_none(data, start, size); return ElementWiseBitsetPolicy<ElementT>::op_none(data, start, size);
} }
// Delegates directly to ElementWiseBitsetPolicy<ElementT>::op_copy, passing
// source, source offset, destination, destination offset and size unchanged.
static void static void
op_copy(const data_type* const src, op_copy(const data_type* const src,
const size_type start_src, const size_t start_src,
data_type* const dst, data_type* const dst,
const size_type start_dst, const size_t start_dst,
const size_type size) { const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_copy( ElementWiseBitsetPolicy<ElementT>::op_copy(
src, start_src, dst, start_dst, size); src, start_src, dst, start_dst, size);
} }
// Returns whatever ElementWiseBitsetPolicy<ElementT>::op_count returns for the
// same (data, start, size); VectorizedT is not consulted here.
static inline size_type static inline size_t
op_count(const data_type* const data, op_count(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
return ElementWiseBitsetPolicy<ElementT>::op_count(data, start, size); return ElementWiseBitsetPolicy<ElementT>::op_count(data, start, size);
} }
// Returns whatever ElementWiseBitsetPolicy<ElementT>::op_eq returns for the
// same two bit ranges; VectorizedT is not consulted here.
static inline bool static inline bool
op_eq(const data_type* const left, op_eq(const data_type* const left,
const data_type* const right, const data_type* const right,
const size_type start_left, const size_t start_left,
const size_type start_right, const size_t start_right,
const size_type size) { const size_t size) {
return ElementWiseBitsetPolicy<ElementT>::op_eq( return ElementWiseBitsetPolicy<ElementT>::op_eq(
left, right, start_left, start_right, size); left, right, start_left, start_right, size);
} }
@ -161,8 +187,11 @@ struct VectorizedElementWiseBitsetPolicy {
const size_t start_left, const size_t start_left,
const size_t start_right, const size_t start_right,
const size_t size) { const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_xor( if (!VectorizedT::template forward_op_xor<ElementT>(
left, right, start_left, start_right, size); left, right, start_left, start_right, size)) {
ElementWiseBitsetPolicy<ElementT>::op_xor(
left, right, start_left, start_right, size);
}
} }
static inline void static inline void
@ -171,24 +200,27 @@ struct VectorizedElementWiseBitsetPolicy {
const size_t start_left, const size_t start_left,
const size_t start_right, const size_t start_right,
const size_t size) { const size_t size) {
ElementWiseBitsetPolicy<ElementT>::op_sub( if (!VectorizedT::template forward_op_sub<ElementT>(
left, right, start_left, start_right, size); left, right, start_left, start_right, size)) {
ElementWiseBitsetPolicy<ElementT>::op_sub(
left, right, start_left, start_right, size);
}
} }
// Delegates directly to ElementWiseBitsetPolicy<ElementT>::op_fill with the
// same (data, start, size, value); VectorizedT is not consulted here.
static void static void
op_fill(data_type* const data, op_fill(data_type* const data,
const size_type start, const size_t start,
const size_type size, const size_t size,
const bool value) { const bool value) {
ElementWiseBitsetPolicy<ElementT>::op_fill(data, start, size, value); ElementWiseBitsetPolicy<ElementT>::op_fill(data, start, size, value);
} }
// //
// Returns the optional index produced by
// ElementWiseBitsetPolicy<ElementT>::op_find for the same
// (data, start, size, starting_idx); VectorizedT is not consulted here.
static inline std::optional<size_type> static inline std::optional<size_t>
op_find(const data_type* const data, op_find(const data_type* const data,
const size_type start, const size_t start,
const size_type size, const size_t size,
const size_type starting_idx) { const size_t starting_idx) {
return ElementWiseBitsetPolicy<ElementT>::op_find( return ElementWiseBitsetPolicy<ElementT>::op_find(
data, start, size, starting_idx); data, start, size, starting_idx);
} }
@ -197,16 +229,16 @@ struct VectorizedElementWiseBitsetPolicy {
template <typename T, typename U, CompareOpType Op> template <typename T, typename U, CompareOpType Op>
static inline void static inline void
op_compare_column(data_type* const __restrict data, op_compare_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const U* const __restrict u, const U* const __restrict u,
const size_type size) { const size_t size) {
op_func( op_func(
start, start,
size, size,
[data, t, u](const size_type starting_bit, [data, t, u](const size_t starting_bit,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
ElementWiseBitsetPolicy<ElementT>:: ElementWiseBitsetPolicy<ElementT>::
template op_compare_column<T, U, Op>(data, template op_compare_column<T, U, Op>(data,
starting_bit, starting_bit,
@ -214,9 +246,9 @@ struct VectorizedElementWiseBitsetPolicy {
u + ptr_offset, u + ptr_offset,
nbits); nbits);
}, },
[data, t, u](const size_type starting_element, [data, t, u](const size_t starting_element,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
return VectorizedT::template op_compare_column<T, U, Op>( return VectorizedT::template op_compare_column<T, U, Op>(
reinterpret_cast<uint8_t*>(data + starting_element), reinterpret_cast<uint8_t*>(data + starting_element),
t + ptr_offset, t + ptr_offset,
@ -229,23 +261,23 @@ struct VectorizedElementWiseBitsetPolicy {
template <typename T, CompareOpType Op> template <typename T, CompareOpType Op>
static inline void static inline void
op_compare_val(data_type* const __restrict data, op_compare_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const size_type size, const size_t size,
const T& value) { const T& value) {
op_func( op_func(
start, start,
size, size,
[data, t, value](const size_type starting_bit, [data, t, value](const size_t starting_bit,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
ElementWiseBitsetPolicy<ElementT>::template op_compare_val<T, ElementWiseBitsetPolicy<ElementT>::template op_compare_val<T,
Op>( Op>(
data, starting_bit, t + ptr_offset, nbits, value); data, starting_bit, t + ptr_offset, nbits, value);
}, },
[data, t, value](const size_type starting_element, [data, t, value](const size_t starting_element,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
return VectorizedT::template op_compare_val<T, Op>( return VectorizedT::template op_compare_val<T, Op>(
reinterpret_cast<uint8_t*>(data + starting_element), reinterpret_cast<uint8_t*>(data + starting_element),
t + ptr_offset, t + ptr_offset,
@ -258,17 +290,17 @@ struct VectorizedElementWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_column(data_type* const __restrict data, op_within_range_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict lower, const T* const __restrict lower,
const T* const __restrict upper, const T* const __restrict upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
op_func( op_func(
start, start,
size, size,
[data, lower, upper, values](const size_type starting_bit, [data, lower, upper, values](const size_t starting_bit,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
ElementWiseBitsetPolicy<ElementT>:: ElementWiseBitsetPolicy<ElementT>::
template op_within_range_column<T, Op>(data, template op_within_range_column<T, Op>(data,
starting_bit, starting_bit,
@ -277,9 +309,9 @@ struct VectorizedElementWiseBitsetPolicy {
values + ptr_offset, values + ptr_offset,
nbits); nbits);
}, },
[data, lower, upper, values](const size_type starting_element, [data, lower, upper, values](const size_t starting_element,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
return VectorizedT::template op_within_range_column<T, Op>( return VectorizedT::template op_within_range_column<T, Op>(
reinterpret_cast<uint8_t*>(data + starting_element), reinterpret_cast<uint8_t*>(data + starting_element),
lower + ptr_offset, lower + ptr_offset,
@ -293,17 +325,17 @@ struct VectorizedElementWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_val(data_type* const __restrict data, op_within_range_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T& lower, const T& lower,
const T& upper, const T& upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
op_func( op_func(
start, start,
size, size,
[data, lower, upper, values](const size_type starting_bit, [data, lower, upper, values](const size_t starting_bit,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
ElementWiseBitsetPolicy<ElementT>:: ElementWiseBitsetPolicy<ElementT>::
template op_within_range_val<T, Op>(data, template op_within_range_val<T, Op>(data,
starting_bit, starting_bit,
@ -312,9 +344,9 @@ struct VectorizedElementWiseBitsetPolicy {
values + ptr_offset, values + ptr_offset,
nbits); nbits);
}, },
[data, lower, upper, values](const size_type starting_element, [data, lower, upper, values](const size_t starting_element,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
return VectorizedT::template op_within_range_val<T, Op>( return VectorizedT::template op_within_range_val<T, Op>(
reinterpret_cast<uint8_t*>(data + starting_element), reinterpret_cast<uint8_t*>(data + starting_element),
lower, lower,
@ -328,17 +360,17 @@ struct VectorizedElementWiseBitsetPolicy {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static inline void static inline void
op_arith_compare(data_type* const __restrict data, op_arith_compare(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict src, const T* const __restrict src,
const ArithHighPrecisionType<T>& right_operand, const ArithHighPrecisionType<T>& right_operand,
const ArithHighPrecisionType<T>& value, const ArithHighPrecisionType<T>& value,
const size_type size) { const size_t size) {
op_func( op_func(
start, start,
size, size,
[data, src, right_operand, value](const size_type starting_bit, [data, src, right_operand, value](const size_t starting_bit,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
ElementWiseBitsetPolicy<ElementT>:: ElementWiseBitsetPolicy<ElementT>::
template op_arith_compare<T, AOp, CmpOp>(data, template op_arith_compare<T, AOp, CmpOp>(data,
starting_bit, starting_bit,
@ -347,9 +379,9 @@ struct VectorizedElementWiseBitsetPolicy {
value, value,
nbits); nbits);
}, },
[data, src, right_operand, value](const size_type starting_element, [data, src, right_operand, value](const size_t starting_element,
const size_type ptr_offset, const size_t ptr_offset,
const size_type nbits) { const size_t nbits) {
return VectorizedT::template op_arith_compare<T, AOp, CmpOp>( return VectorizedT::template op_arith_compare<T, AOp, CmpOp>(
reinterpret_cast<uint8_t*>(data + starting_element), reinterpret_cast<uint8_t*>(data + starting_element),
src + ptr_offset, src + ptr_offset,
@ -380,12 +412,12 @@ struct VectorizedElementWiseBitsetPolicy {
left, right, start_left, start_right, size); left, right, start_left, start_right, size);
} }
// void FuncBaseline(const size_t starting_bit, const size_type ptr_offset, const size_type nbits) // void FuncBaseline(const size_t starting_bit, const size_t ptr_offset, const size_t nbits)
// bool FuncVectorized(const size_type starting_element, const size_type ptr_offset, const size_type nbits) // bool FuncVectorized(const size_t starting_element, const size_t ptr_offset, const size_t nbits)
template <typename FuncBaseline, typename FuncVectorized> template <typename FuncBaseline, typename FuncVectorized>
static inline void static inline void
op_func(const size_type start, op_func(const size_t start,
const size_type size, const size_t size,
FuncBaseline func_baseline, FuncBaseline func_baseline,
FuncVectorized func_vectorized) { FuncVectorized func_vectorized) {
if (size == 0) { if (size == 0) {

View File

@ -26,6 +26,8 @@
#include "popcount.h" #include "popcount.h"
#include "bitset/common.h" #include "bitset/common.h"
#include "maybe_vector.h"
namespace milvus { namespace milvus {
namespace bitset { namespace bitset {
namespace detail { namespace detail {
@ -34,53 +36,51 @@ namespace detail {
template <typename ElementT> template <typename ElementT>
struct ElementWiseBitsetPolicy { struct ElementWiseBitsetPolicy {
using data_type = ElementT; using data_type = ElementT;
constexpr static auto data_bits = sizeof(data_type) * 8; constexpr static size_t data_bits = sizeof(data_type) * 8;
using size_type = size_t;
using self_type = ElementWiseBitsetPolicy<ElementT>; using self_type = ElementWiseBitsetPolicy<ElementT>;
using proxy_type = Proxy<self_type>; using proxy_type = Proxy<self_type>;
using const_proxy_type = ConstProxy<self_type>; using const_proxy_type = ConstProxy<self_type>;
// Index of the data_type element that contains bit `idx` (idx / data_bits).
static inline size_type static inline size_t
get_element(const size_t idx) { get_element(const size_t idx) {
return idx / data_bits; return idx / data_bits;
} }
// Position of bit `idx` inside its element (idx % data_bits).
static inline size_type static inline size_t
get_shift(const size_t idx) { get_shift(const size_t idx) {
return idx % data_bits; return idx % data_bits;
} }
static inline size_type static inline size_t
get_required_size_in_elements(const size_t size) { get_required_size_in_elements(const size_t size) {
return (size + data_bits - 1) / data_bits; return (size + data_bits - 1) / data_bits;
} }
static inline size_type static inline size_t
get_required_size_in_bytes(const size_t size) { get_required_size_in_bytes(const size_t size) {
return get_required_size_in_elements(size) * sizeof(data_type); return get_required_size_in_elements(size) * sizeof(data_type);
} }
static inline proxy_type static inline proxy_type
get_proxy(data_type* const __restrict data, const size_type idx) { get_proxy(data_type* const __restrict data, const size_t idx) {
data_type& element = data[get_element(idx)]; data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return proxy_type{element, shift}; return proxy_type{element, shift};
} }
static inline const_proxy_type static inline const_proxy_type
get_proxy(const data_type* const __restrict data, const size_type idx) { get_proxy(const data_type* const __restrict data, const size_t idx) {
const data_type& element = data[get_element(idx)]; const data_type& element = data[get_element(idx)];
const size_type shift = get_shift(idx); const size_t shift = get_shift(idx);
return const_proxy_type{element, shift}; return const_proxy_type{element, shift};
} }
static inline data_type static inline data_type
op_read(const data_type* const data, op_read(const data_type* const data,
const size_type start, const size_t start,
const size_type nbits) { const size_t nbits) {
if (nbits == 0) { if (nbits == 0) {
return 0; return 0;
} }
@ -121,8 +121,8 @@ struct ElementWiseBitsetPolicy {
static inline void static inline void
op_write(data_type* const data, op_write(data_type* const data,
const size_type start, const size_t start,
const size_type nbits, const size_t nbits,
const data_type value) { const data_type value) {
if (nbits == 0) { if (nbits == 0) {
return; return;
@ -169,9 +169,7 @@ struct ElementWiseBitsetPolicy {
} }
static inline void static inline void
op_flip(data_type* const data, op_flip(data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
if (size == 0) { if (size == 0) {
return; return;
} }
@ -211,7 +209,7 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
data[i] = ~data[i]; data[i] = ~data[i];
} }
@ -228,7 +226,7 @@ struct ElementWiseBitsetPolicy {
} }
} }
static inline void static BITSET_ALWAYS_INLINE inline void
op_and(data_type* const left, op_and(data_type* const left,
const data_type* const right, const data_type* const right,
const size_t start_left, const size_t start_left,
@ -244,7 +242,25 @@ struct ElementWiseBitsetPolicy {
}); });
} }
static inline void static BITSET_ALWAYS_INLINE inline void
op_and_multiple(data_type* const left,
const data_type* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
op_func(left,
rights,
start_left,
start_rights,
n_rights,
size,
[](const data_type left_v, const data_type right_v) {
return left_v & right_v;
});
}
static BITSET_ALWAYS_INLINE inline void
op_or(data_type* const left, op_or(data_type* const left,
const data_type* const right, const data_type* const right,
const size_t start_left, const size_t start_left,
@ -260,8 +276,26 @@ struct ElementWiseBitsetPolicy {
}); });
} }
// In-place bitwise OR of `left` with each of the `n_rights` bitsets listed
// in `rights`. The i-th source starts at `start_rights[i]`, the destination
// starts at `start_left`, and `size` is the amount to process; all of them
// are forwarded unchanged to the multi-source op_func overload, which folds
// every source into the destination using the combiner below.
static BITSET_ALWAYS_INLINE inline void
op_or_multiple(data_type* const left,
               const data_type* const* const rights,
               const size_t start_left,
               const size_t* const __restrict start_rights,
               const size_t n_rights,
               const size_t size) {
    // element-level combiner: keep the bits set in either operand
    const auto bitwise_or = [](const data_type lhs, const data_type rhs) {
        return lhs | rhs;
    };

    op_func(
        left, rights, start_left, start_rights, n_rights, size, bitwise_or);
}
static inline data_type static inline data_type
get_shift_mask_begin(const size_type shift) { get_shift_mask_begin(const size_t shift) {
// 0 -> 0b00000000 // 0 -> 0b00000000
// 1 -> 0b00000001 // 1 -> 0b00000001
// 2 -> 0b00000011 // 2 -> 0b00000011
@ -273,7 +307,7 @@ struct ElementWiseBitsetPolicy {
} }
static inline data_type static inline data_type
get_shift_mask_end(const size_type shift) { get_shift_mask_end(const size_t shift) {
// 0 -> 0b11111111 // 0 -> 0b11111111
// 1 -> 0b11111110 // 1 -> 0b11111110
// 2 -> 0b11111100 // 2 -> 0b11111100
@ -281,21 +315,17 @@ struct ElementWiseBitsetPolicy {
} }
static inline void static inline void
op_set(data_type* const data, const size_type start, const size_type size) { op_set(data_type* const data, const size_t start, const size_t size) {
op_fill(data, start, size, true); op_fill(data, start, size, true);
} }
static inline void static inline void
op_reset(data_type* const data, op_reset(data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
op_fill(data, start, size, false); op_fill(data, start, size, false);
} }
static inline bool static inline bool
op_all(const data_type* const data, op_all(const data_type* const data, const size_t start, const size_t size) {
const size_type start,
const size_type size) {
if (size == 0) { if (size == 0) {
return true; return true;
} }
@ -329,7 +359,7 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
if (data[i] != data_type(-1)) { if (data[i] != data_type(-1)) {
return false; return false;
} }
@ -351,8 +381,8 @@ struct ElementWiseBitsetPolicy {
static inline bool static inline bool
op_none(const data_type* const data, op_none(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
if (size == 0) { if (size == 0) {
return true; return true;
} }
@ -386,7 +416,7 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
if (data[i] != data_type(0)) { if (data[i] != data_type(0)) {
return false; return false;
} }
@ -408,27 +438,27 @@ struct ElementWiseBitsetPolicy {
static void static void
op_copy(const data_type* const src, op_copy(const data_type* const src,
const size_type start_src, const size_t start_src,
data_type* const dst, data_type* const dst,
const size_type start_dst, const size_t start_dst,
const size_type size) { const size_t size) {
if (size == 0) { if (size == 0) {
return; return;
} }
// process big blocks // process big blocks
const size_type size_b = (size / data_bits) * data_bits; const size_t size_b = (size / data_bits) * data_bits;
if ((start_src % data_bits) == 0) { if ((start_src % data_bits) == 0) {
if ((start_dst % data_bits) == 0) { if ((start_dst % data_bits) == 0) {
// plain memcpy // plain memcpy
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type src_v = src[(start_src + i) / data_bits]; const data_type src_v = src[(start_src + i) / data_bits];
dst[(start_dst + i) / data_bits] = src_v; dst[(start_dst + i) / data_bits] = src_v;
} }
} else { } else {
// easier read // easier read
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type src_v = src[(start_src + i) / data_bits]; const data_type src_v = src[(start_src + i) / data_bits];
op_write(dst, start_dst + i, data_bits, src_v); op_write(dst, start_dst + i, data_bits, src_v);
} }
@ -436,14 +466,14 @@ struct ElementWiseBitsetPolicy {
} else { } else {
if ((start_dst % data_bits) == 0) { if ((start_dst % data_bits) == 0) {
// easier write // easier write
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type src_v = const data_type src_v =
op_read(src, start_src + i, data_bits); op_read(src, start_src + i, data_bits);
dst[(start_dst + i) / data_bits] = src_v; dst[(start_dst + i) / data_bits] = src_v;
} }
} else { } else {
// general case // general case
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type src_v = const data_type src_v =
op_read(src, start_src + i, data_bits); op_read(src, start_src + i, data_bits);
op_write(dst, start_dst + i, data_bits, src_v); op_write(dst, start_dst + i, data_bits, src_v);
@ -461,8 +491,8 @@ struct ElementWiseBitsetPolicy {
static void static void
op_fill(data_type* const data, op_fill(data_type* const data,
const size_type start, const size_t start,
const size_type size, const size_t size,
const bool value) { const bool value) {
if (size == 0) { if (size == 0) {
return; return;
@ -504,7 +534,7 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
data[i] = new_v; data[i] = new_v;
} }
@ -520,15 +550,15 @@ struct ElementWiseBitsetPolicy {
} }
} }
static inline size_type static inline size_t
op_count(const data_type* const data, op_count(const data_type* const data,
const size_type start, const size_t start,
const size_type size) { const size_t size) {
if (size == 0) { if (size == 0) {
return 0; return 0;
} }
size_type count = 0; size_t count = 0;
auto start_element = get_element(start); auto start_element = get_element(start);
const auto end_element = get_element(start + size); const auto end_element = get_element(start + size);
@ -558,7 +588,7 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
count += PopCountHelper<data_type>::count(data[i]); count += PopCountHelper<data_type>::count(data[i]);
} }
@ -577,24 +607,23 @@ struct ElementWiseBitsetPolicy {
static inline bool static inline bool
op_eq(const data_type* const left, op_eq(const data_type* const left,
const data_type* const right, const data_type* const right,
const size_type start_left, const size_t start_left,
const size_type start_right, const size_t start_right,
const size_type size) { const size_t size) {
if (size == 0) { if (size == 0) {
return true; return true;
} }
// process big chunks // process big chunks
const size_type size_b = (size / data_bits) * data_bits; const size_t size_b = (size / data_bits) * data_bits;
if ((start_left % data_bits) == 0) { if ((start_left % data_bits) == 0) {
if ((start_right % data_bits) == 0) { if ((start_right % data_bits) == 0) {
// plain "memcpy" // plain "memcpy"
size_type start_left_idx = start_left / data_bits; size_t start_left_idx = start_left / data_bits;
size_type start_right_idx = start_right / data_bits; size_t start_right_idx = start_right / data_bits;
for (size_type i = 0, j = 0; i < size_b; for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
i += data_bits, j += 1) {
const data_type left_v = left[start_left_idx + j]; const data_type left_v = left[start_left_idx + j];
const data_type right_v = right[start_right_idx + j]; const data_type right_v = right[start_right_idx + j];
if (left_v != right_v) { if (left_v != right_v) {
@ -603,10 +632,9 @@ struct ElementWiseBitsetPolicy {
} }
} else { } else {
// easier left // easier left
size_type start_left_idx = start_left / data_bits; size_t start_left_idx = start_left / data_bits;
for (size_type i = 0, j = 0; i < size_b; for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
i += data_bits, j += 1) {
const data_type left_v = left[start_left_idx + j]; const data_type left_v = left[start_left_idx + j];
const data_type right_v = const data_type right_v =
op_read(right, start_right + i, data_bits); op_read(right, start_right + i, data_bits);
@ -618,10 +646,9 @@ struct ElementWiseBitsetPolicy {
} else { } else {
if ((start_right % data_bits) == 0) { if ((start_right % data_bits) == 0) {
// easier right // easier right
size_type start_right_idx = start_right / data_bits; size_t start_right_idx = start_right / data_bits;
for (size_type i = 0, j = 0; i < size_b; for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
i += data_bits, j += 1) {
const data_type left_v = const data_type left_v =
op_read(left, start_left + i, data_bits); op_read(left, start_left + i, data_bits);
const data_type right_v = right[start_right_idx + j]; const data_type right_v = right[start_right_idx + j];
@ -631,7 +658,7 @@ struct ElementWiseBitsetPolicy {
} }
} else { } else {
// general case // general case
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type left_v = const data_type left_v =
op_read(left, start_left + i, data_bits); op_read(left, start_left + i, data_bits);
const data_type right_v = const data_type right_v =
@ -657,7 +684,7 @@ struct ElementWiseBitsetPolicy {
return true; return true;
} }
static inline void static BITSET_ALWAYS_INLINE inline void
op_xor(data_type* const left, op_xor(data_type* const left,
const data_type* const right, const data_type* const right,
const size_t start_left, const size_t start_left,
@ -673,7 +700,7 @@ struct ElementWiseBitsetPolicy {
}); });
} }
static inline void static BITSET_ALWAYS_INLINE inline void
op_sub(data_type* const left, op_sub(data_type* const left,
const data_type* const right, const data_type* const right,
const size_t start_left, const size_t start_left,
@ -690,11 +717,11 @@ struct ElementWiseBitsetPolicy {
} }
// //
static inline std::optional<size_type> static inline std::optional<size_t>
op_find(const data_type* const data, op_find(const data_type* const data,
const size_type start, const size_t start,
const size_type size, const size_t size,
const size_type starting_idx) { const size_t starting_idx) {
if (size == 0) { if (size == 0) {
return std::nullopt; return std::nullopt;
} }
@ -706,7 +733,7 @@ struct ElementWiseBitsetPolicy {
const auto start_shift = get_shift(start + starting_idx); const auto start_shift = get_shift(start + starting_idx);
const auto end_shift = get_shift(start + size); const auto end_shift = get_shift(start + size);
size_type extra_offset = 0; size_t extra_offset = 0;
// same element? // same element?
if (start_element == end_element) { if (start_element == end_element) {
@ -718,7 +745,7 @@ struct ElementWiseBitsetPolicy {
const data_type value = existing_v & existing_mask; const data_type value = existing_v & existing_mask;
if (value != 0) { if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value); const auto ctz = CtzHelper<data_type>::ctz(value);
return size_type(ctz) + start_element * data_bits - start; return size_t(ctz) + start_element * data_bits - start;
} else { } else {
return std::nullopt; return std::nullopt;
} }
@ -733,7 +760,7 @@ struct ElementWiseBitsetPolicy {
if (value != 0) { if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value) + const auto ctz = CtzHelper<data_type>::ctz(value) +
start_element * data_bits - start; start_element * data_bits - start;
return size_type(ctz); return size_t(ctz);
} }
start_element += 1; start_element += 1;
@ -741,11 +768,11 @@ struct ElementWiseBitsetPolicy {
} }
// process the middle // process the middle
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
const data_type value = data[i]; const data_type value = data[i];
if (value != 0) { if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value); const auto ctz = CtzHelper<data_type>::ctz(value);
return size_type(ctz) + i * data_bits - start; return size_t(ctz) + i * data_bits - start;
} }
} }
@ -757,7 +784,7 @@ struct ElementWiseBitsetPolicy {
const data_type value = existing_v & existing_mask; const data_type value = existing_v & existing_mask;
if (value != 0) { if (value != 0) {
const auto ctz = CtzHelper<data_type>::ctz(value); const auto ctz = CtzHelper<data_type>::ctz(value);
return size_type(ctz) + end_element * data_bits - start; return size_t(ctz) + end_element * data_bits - start;
} }
} }
@ -768,11 +795,11 @@ struct ElementWiseBitsetPolicy {
template <typename T, typename U, CompareOpType Op> template <typename T, typename U, CompareOpType Op>
static inline void static inline void
op_compare_column(data_type* const __restrict data, op_compare_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const U* const __restrict u, const U* const __restrict u,
const size_type size) { const size_t size) {
op_func(data, start, size, [t, u](const size_type bit_idx) { op_func(data, start, size, [t, u](const size_t bit_idx) {
return CompareOperator<Op>::compare(t[bit_idx], u[bit_idx]); return CompareOperator<Op>::compare(t[bit_idx], u[bit_idx]);
}); });
} }
@ -781,11 +808,11 @@ struct ElementWiseBitsetPolicy {
template <typename T, CompareOpType Op> template <typename T, CompareOpType Op>
static inline void static inline void
op_compare_val(data_type* const __restrict data, op_compare_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict t, const T* const __restrict t,
const size_type size, const size_t size,
const T& value) { const T& value) {
op_func(data, start, size, [t, value](const size_type bit_idx) { op_func(data, start, size, [t, value](const size_t bit_idx) {
return CompareOperator<Op>::compare(t[bit_idx], value); return CompareOperator<Op>::compare(t[bit_idx], value);
}); });
} }
@ -794,13 +821,13 @@ struct ElementWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_column(data_type* const __restrict data, op_within_range_column(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict lower, const T* const __restrict lower,
const T* const __restrict upper, const T* const __restrict upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
op_func( op_func(
data, start, size, [lower, upper, values](const size_type bit_idx) { data, start, size, [lower, upper, values](const size_t bit_idx) {
return RangeOperator<Op>::within_range( return RangeOperator<Op>::within_range(
lower[bit_idx], upper[bit_idx], values[bit_idx]); lower[bit_idx], upper[bit_idx], values[bit_idx]);
}); });
@ -810,13 +837,13 @@ struct ElementWiseBitsetPolicy {
template <typename T, RangeType Op> template <typename T, RangeType Op>
static inline void static inline void
op_within_range_val(data_type* const __restrict data, op_within_range_val(data_type* const __restrict data,
const size_type start, const size_t start,
const T& lower, const T& lower,
const T& upper, const T& upper,
const T* const __restrict values, const T* const __restrict values,
const size_type size) { const size_t size) {
op_func( op_func(
data, start, size, [lower, upper, values](const size_type bit_idx) { data, start, size, [lower, upper, values](const size_t bit_idx) {
return RangeOperator<Op>::within_range( return RangeOperator<Op>::within_range(
lower, upper, values[bit_idx]); lower, upper, values[bit_idx]);
}); });
@ -826,15 +853,15 @@ struct ElementWiseBitsetPolicy {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static inline void static inline void
op_arith_compare(data_type* const __restrict data, op_arith_compare(data_type* const __restrict data,
const size_type start, const size_t start,
const T* const __restrict src, const T* const __restrict src,
const ArithHighPrecisionType<T>& right_operand, const ArithHighPrecisionType<T>& right_operand,
const ArithHighPrecisionType<T>& value, const ArithHighPrecisionType<T>& value,
const size_type size) { const size_t size) {
op_func(data, op_func(data,
start, start,
size, size,
[src, right_operand, value](const size_type bit_idx) { [src, right_operand, value](const size_t bit_idx) {
return ArithCompareOperator<AOp, CmpOp>::compare( return ArithCompareOperator<AOp, CmpOp>::compare(
src[bit_idx], right_operand, value); src[bit_idx], right_operand, value);
}); });
@ -872,11 +899,14 @@ struct ElementWiseBitsetPolicy {
const size_t size) { const size_t size) {
size_t inactive = 0; size_t inactive = 0;
const size_t size_b = (size / data_bits) * data_bits;
// process bulk
op_func(left, op_func(left,
right, right,
start_left, start_left,
start_right, start_right,
size, size_b,
[&inactive](const data_type left_v, const data_type right_v) { [&inactive](const data_type left_v, const data_type right_v) {
const data_type result = left_v | right_v; const data_type result = left_v | right_v;
inactive += inactive +=
@ -885,12 +915,25 @@ struct ElementWiseBitsetPolicy {
return result; return result;
}); });
// process leftovers
if (size != size_b) {
const data_type left_v =
op_read(left, start_left + size_b, size - size_b);
const data_type right_v =
op_read(right, start_right + size_b, size - size_b);
const data_type result_v = left_v | right_v;
inactive +=
(size - size_b - PopCountHelper<data_type>::count(result_v));
op_write(left, start_left + size_b, size - size_b, result_v);
}
return inactive; return inactive;
} }
// data_type Func(const data_type left_v, const data_type right_v); // data_type Func(const data_type left_v, const data_type right_v);
template <typename Func> template <typename Func>
static inline void static BITSET_ALWAYS_INLINE inline void
op_func(data_type* const left, op_func(data_type* const left,
const data_type* const right, const data_type* const right,
const size_t start_left, const size_t start_left,
@ -902,16 +945,15 @@ struct ElementWiseBitsetPolicy {
} }
// process big blocks // process big blocks
const size_type size_b = (size / data_bits) * data_bits; const size_t size_b = (size / data_bits) * data_bits;
if ((start_left % data_bits) == 0) { if ((start_left % data_bits) == 0) {
if ((start_right % data_bits) == 0) { if ((start_right % data_bits) == 0) {
// plain "memcpy". // plain "memcpy".
// A compiler auto-vectorization is expected. // A compiler auto-vectorization is expected.
size_type start_left_idx = start_left / data_bits; size_t start_left_idx = start_left / data_bits;
size_type start_right_idx = start_right / data_bits; size_t start_right_idx = start_right / data_bits;
for (size_type i = 0, j = 0; i < size_b; for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
i += data_bits, j += 1) {
data_type& left_v = left[start_left_idx + j]; data_type& left_v = left[start_left_idx + j];
const data_type right_v = right[start_right_idx + j]; const data_type right_v = right[start_right_idx + j];
@ -920,25 +962,9 @@ struct ElementWiseBitsetPolicy {
} }
} else { } else {
// easier read // easier read
size_type start_right_idx = start_right / data_bits; size_t start_left_idx = start_left / data_bits;
for (size_type i = 0, j = 0; i < size_b; for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
i += data_bits, j += 1) {
const data_type left_v =
op_read(left, start_left + i, data_bits);
const data_type right_v = right[start_right_idx + j];
const data_type result_v = func(left_v, right_v);
op_write(left, start_right + i, data_bits, result_v);
}
}
} else {
if ((start_right % data_bits) == 0) {
// easier write
size_type start_left_idx = start_left / data_bits;
for (size_type i = 0, j = 0; i < size_b;
i += data_bits, j += 1) {
data_type& left_v = left[start_left_idx + j]; data_type& left_v = left[start_left_idx + j];
const data_type right_v = const data_type right_v =
op_read(right, start_right + i, data_bits); op_read(right, start_right + i, data_bits);
@ -946,16 +972,30 @@ struct ElementWiseBitsetPolicy {
const data_type result_v = func(left_v, right_v); const data_type result_v = func(left_v, right_v);
left_v = result_v; left_v = result_v;
} }
}
} else {
if ((start_right % data_bits) == 0) {
// easier write
size_t start_right_idx = start_right / data_bits;
for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
const data_type left_v =
op_read(left, start_left + i, data_bits);
const data_type right_v = right[start_right_idx + j];
const data_type result_v = func(left_v, right_v);
op_write(left, start_left + i, data_bits, result_v);
}
} else { } else {
// general case // general case
for (size_type i = 0; i < size_b; i += data_bits) { for (size_t i = 0; i < size_b; i += data_bits) {
const data_type left_v = const data_type left_v =
op_read(left, start_left + i, data_bits); op_read(left, start_left + i, data_bits);
const data_type right_v = const data_type right_v =
op_read(right, start_right + i, data_bits); op_read(right, start_right + i, data_bits);
const data_type result_v = func(left_v, right_v); const data_type result_v = func(left_v, right_v);
op_write(left, start_right + i, data_bits, result_v); op_write(left, start_left + i, data_bits, result_v);
} }
} }
} }
@ -972,11 +1012,145 @@ struct ElementWiseBitsetPolicy {
} }
} }
// bool Func(const size_type bit_idx); // data_type Func(const data_type left_v, const data_type right_v);
template <typename Func> template <typename Func>
static inline void static BITSET_ALWAYS_INLINE inline void
op_func(data_type* const left,
const data_type* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size,
Func func) {
if (size == 0 || n_rights == 0) {
return;
}
if (n_rights == 1) {
op_func<Func>(
left, rights[0], start_left, start_rights[0], size, func);
return;
}
// process big blocks
const size_t size_b = (size / data_bits) * data_bits;
// check a specific case
bool all_aligned = true;
for (size_t i = 0; i < n_rights; i++) {
if (start_rights[i] % data_bits != 0) {
all_aligned = false;
break;
}
}
// all are aligned
if (all_aligned) {
MaybeVector<const data_type*> tmp(n_rights);
for (size_t i = 0; i < n_rights; i++) {
tmp[i] = rights[i] + (start_rights[i] / data_bits);
}
// plain "memcpy".
// A compiler auto-vectorization is expected.
const size_t start_left_idx = start_left / data_bits;
data_type* left_ptr = left + start_left_idx;
auto unrolled = [left_ptr, &tmp, func, size_b](const size_t count) {
for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) {
data_type& left_v = left_ptr[j];
data_type value = left_v;
for (size_t k = 0; k < count; k++) {
const data_type right_v = tmp[k][j];
value = func(value, right_v);
}
left_v = value;
}
};
switch (n_rights) {
// case 1: unrolled(1); break;
case 2:
unrolled(2);
break;
case 3:
unrolled(3);
break;
case 4:
unrolled(4);
break;
case 5:
unrolled(5);
break;
case 6:
unrolled(6);
break;
case 7:
unrolled(7);
break;
case 8:
unrolled(8);
break;
default: {
for (size_t i = 0, j = 0; i < size_b;
i += data_bits, j += 1) {
data_type& left_v = left_ptr[j];
data_type value = left_v;
for (size_t k = 0; k < n_rights; k++) {
const data_type right_v = tmp[k][j];
value = func(value, right_v);
}
left_v = value;
}
}
}
} else {
// general case. Unoptimized.
for (size_t i = 0; i < size_b; i += data_bits) {
const data_type left_v =
op_read(left, start_left + i, data_bits);
data_type value = left_v;
for (size_t k = 0; k < n_rights; k++) {
const data_type right_v =
op_read(rights[k], start_rights[k] + i, data_bits);
value = func(value, right_v);
}
op_write(left, start_left + i, data_bits, value);
}
}
// process leftovers
if (size_b != size) {
const data_type left_v =
op_read(left, start_left + size_b, size - size_b);
data_type value = left_v;
for (size_t k = 0; k < n_rights; k++) {
const data_type right_v =
op_read(rights[k], start_rights[k] + size_b, size - size_b);
value = func(value, right_v);
}
op_write(left, start_left + size_b, size - size_b, value);
}
}
// bool Func(const size_t bit_idx);
template <typename Func>
static BITSET_ALWAYS_INLINE inline void
op_func(data_type* const __restrict data, op_func(data_type* const __restrict data,
const size_type start, const size_t start,
const size_t size, const size_t size,
Func func) { Func func) {
if (size == 0) { if (size == 0) {
@ -991,7 +1165,7 @@ struct ElementWiseBitsetPolicy {
if (start_element == end_element) { if (start_element == end_element) {
data_type bits = 0; data_type bits = 0;
for (size_type j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
const bool bit = func(j); const bool bit = func(j);
// // a curious example where the compiler does not optimize the code properly // // a curious example where the compiler does not optimize the code properly
// bits |= (bit ? (data_type(1) << j) : 0); // bits |= (bit ? (data_type(1) << j) : 0);
@ -1009,10 +1183,10 @@ struct ElementWiseBitsetPolicy {
// process the first element // process the first element
if (start_shift != 0) { if (start_shift != 0) {
const size_type n_bits = data_bits - start_shift; const size_t n_bits = data_bits - start_shift;
data_type bits = 0; data_type bits = 0;
for (size_type j = 0; j < n_bits; j++) { for (size_t j = 0; j < n_bits; j++) {
const bool bit = func(j); const bool bit = func(j);
bits |= (data_type(bit ? 1 : 0) << j); bits |= (data_type(bit ? 1 : 0) << j);
} }
@ -1026,9 +1200,9 @@ struct ElementWiseBitsetPolicy {
// process the middle // process the middle
{ {
for (size_type i = start_element; i < end_element; i++) { for (size_t i = start_element; i < end_element; i++) {
data_type bits = 0; data_type bits = 0;
for (size_type j = 0; j < data_bits; j++) { for (size_t j = 0; j < data_bits; j++) {
const bool bit = func(ptr_offset + j); const bool bit = func(ptr_offset + j);
bits |= (data_type(bit ? 1 : 0) << j); bits |= (data_type(bit ? 1 : 0) << j);
} }
@ -1041,7 +1215,7 @@ struct ElementWiseBitsetPolicy {
// process the last element // process the last element
if (end_shift != 0) { if (end_shift != 0) {
data_type bits = 0; data_type bits = 0;
for (size_type j = 0; j < end_shift; j++) { for (size_t j = 0; j < end_shift; j++) {
const bool bit = func(ptr_offset + j); const bool bit = func(ptr_offset + j);
bits |= (data_type(bit ? 1 : 0) << j); bits |= (data_type(bit ? 1 : 0) << j);
} }

View File

@ -0,0 +1,91 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <memory>
#include <type_traits>
namespace milvus {
namespace bitset {
namespace detail {
// A fixed-size scratch buffer of scalar elements.
// If the number of elements is small (up to num_array_elements),
// the elements live in the in-object array and no heap allocation is done.
// If the number of elements is large,
// the elements live in a heap block that is owned by this object
// and is released together with it.
// In both cases the elements start value-initialized (0 / nullptr).
// The object is neither copyable nor movable, because the data pointer
// may refer to storage inside the object itself.
template <typename T>
struct MaybeVector {
 public:
    static_assert(std::is_scalar_v<T>);

    // the capacity of the in-object storage
    static constexpr size_t num_array_elements = 64;

    // heap storage; stays empty whenever the in-object array is used
    std::unique_ptr<T[]> maybe_memory;
    // in-object storage; value-initialized, so that reading an element
    // before writing it behaves the same way as for the heap storage
    std::array<T, num_array_elements> maybe_array{};

    // Provides a storage for n_elements elements.
    MaybeVector(const size_t n_elements) : m_size(n_elements) {
        if (n_elements <= num_array_elements) {
            // fits into the in-object array, including the exact fit
            m_data = maybe_array.data();
        } else {
            // std::make_unique<T[]> value-initializes the elements
            maybe_memory = std::make_unique<T[]>(n_elements);
            m_data = maybe_memory.get();
        }
    }

    MaybeVector(const MaybeVector&) = delete;
    MaybeVector(MaybeVector&&) = delete;
    MaybeVector&
    operator=(const MaybeVector&) = delete;
    MaybeVector&
    operator=(MaybeVector&&) = delete;

    // Returns the number of elements that was requested on construction.
    inline size_t
    size() const {
        return m_size;
    }

    // Returns a pointer to the first element of the active storage.
    inline T*
    data() {
        return m_data;
    }

    inline const T*
    data() const {
        return m_data;
    }

    // Unchecked element access.
    inline T&
    operator[](const size_t idx) {
        return m_data[idx];
    }

    inline const T&
    operator[](const size_t idx) const {
        return m_data[idx];
    }

 private:
    size_t m_size = 0;
    // points either into maybe_array or into maybe_memory
    T* m_data = nullptr;
};
} // namespace detail
} // namespace bitset
} // namespace milvus

View File

@ -39,6 +39,11 @@ namespace neon {
FUNC(float); \ FUNC(float); \
FUNC(double); FUNC(double);
// A facility to run through all acceptable forward element types.
// Expands to `FUNC(uint8_t); FUNC(uint64_t);`, i.e. FUNC is applied
// once per type and every application is followed by a semicolon.
#define ALL_FORWARD_TYPES_1(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// the default implementation does nothing // the default implementation does nothing
@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// forward ops
// Generic version of the forward ops.
// None of the functions touches its arguments; each one simply
// returns `false`.
template <typename ElementT>
struct ForwardOpsImpl {
 private:
    // the result that every generic op reports
    static constexpr bool kNothingDone = false;

 public:
    // left & right
    static inline bool
    op_and(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // left & rights[0] & ... & rights[n_rights - 1]
    static inline bool
    op_and_multiple(ElementT* const left,
                    const ElementT* const* const rights,
                    const size_t start_left,
                    const size_t* const __restrict start_rights,
                    const size_t n_rights,
                    const size_t size) {
        return kNothingDone;
    }

    // left | right
    static inline bool
    op_or(ElementT* const left,
          const ElementT* const right,
          const size_t start_left,
          const size_t start_right,
          const size_t size) {
        return kNothingDone;
    }

    // left | rights[0] | ... | rights[n_rights - 1]
    static inline bool
    op_or_multiple(ElementT* const left,
                   const ElementT* const* const rights,
                   const size_t start_left,
                   const size_t* const __restrict start_rights,
                   const size_t n_rights,
                   const size_t size) {
        return kNothingDone;
    }

    // left ^ right
    static inline bool
    op_xor(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // the `sub` op
    static inline bool
    op_sub(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }
};
// Declares a full specialization of ForwardOpsImpl for ELEMENTTYPE.
// Only the signatures of the six static functions are declared here,
// no bodies; the definitions have to be provided out of line.
// NOTE: do not put `//` comments inside the macro body, because the
// line continuations would make them swallow the following lines.
#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \
template <> \
struct ForwardOpsImpl<ELEMENTTYPE> { \
static bool \
op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_and_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_or_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
};
// Emit one specialization per type listed in ALL_FORWARD_TYPES_1.
ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS)
// The helper macro is not needed past this point.
#undef DECLARE_PARTIAL_FORWARD_OPS
///////////////////////////////////////////////////////////////////////////
#undef ALL_DATATYPES_1 #undef ALL_DATATYPES_1
#undef ALL_FORWARD_TYPES_1
} // namespace neon } // namespace neon
} // namespace arm } // namespace arm

View File

@ -28,6 +28,7 @@
#include "neon-decl.h" #include "neon-decl.h"
#include "bitset/common.h" #include "bitset/common.h"
#include "bitset/detail/element_wise.h"
namespace milvus { namespace milvus {
namespace bitset { namespace bitset {
@ -1810,6 +1811,151 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
} }
} }
///////////////////////////////////////////////////////////////////////////
// forward ops
//
// Definitions of the forward ops for uint8_t elements.
// Every function hands its arguments over, unchanged, to the
// same-named function of ElementWiseBitsetPolicy<uint8_t>
// and then reports `true`.

bool
ForwardOpsImpl<uint8_t>::op_and(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_and_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_or(uint8_t* const left,
                               const uint8_t* const right,
                               const size_t start_left,
                               const size_t start_right,
                               const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_or_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_xor(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_sub(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
//
// Definitions of the forward ops for uint64_t elements.
// Every function hands its arguments over, unchanged, to the
// same-named function of ElementWiseBitsetPolicy<uint64_t>
// and then reports `true`.

bool
ForwardOpsImpl<uint64_t>::op_and(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_and_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_or(uint64_t* const left,
                                const uint64_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_or_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_xor(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_sub(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
} // namespace neon } // namespace neon

View File

@ -55,6 +55,30 @@ struct VectorizedNeon {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static constexpr inline auto op_arith_compare = static constexpr inline auto op_arith_compare =
neon::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare; neon::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare;
// Function pointers to the forward ops of neon::ForwardOpsImpl,
// one variable template per operation.
template <typename ElementT>
static constexpr inline auto forward_op_and =
    &neon::ForwardOpsImpl<ElementT>::op_and;

template <typename ElementT>
static constexpr inline auto forward_op_and_multiple =
    &neon::ForwardOpsImpl<ElementT>::op_and_multiple;

template <typename ElementT>
static constexpr inline auto forward_op_or =
    &neon::ForwardOpsImpl<ElementT>::op_or;

template <typename ElementT>
static constexpr inline auto forward_op_or_multiple =
    &neon::ForwardOpsImpl<ElementT>::op_or_multiple;

template <typename ElementT>
static constexpr inline auto forward_op_xor =
    &neon::ForwardOpsImpl<ElementT>::op_xor;

template <typename ElementT>
static constexpr inline auto forward_op_sub =
    &neon::ForwardOpsImpl<ElementT>::op_sub;
}; };
} // namespace arm } // namespace arm

View File

@ -39,6 +39,11 @@ namespace sve {
FUNC(float); \ FUNC(float); \
FUNC(double); FUNC(double);
// A facility to run through all acceptable forward element types.
// Expands to `FUNC(uint8_t); FUNC(uint64_t);`, i.e. FUNC is applied
// once per type and every application is followed by a semicolon.
#define ALL_FORWARD_TYPES_1(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// the default implementation does nothing // the default implementation does nothing
@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// forward ops
// Generic version of the forward ops.
// None of the functions touches its arguments; each one simply
// returns `false`.
template <typename ElementT>
struct ForwardOpsImpl {
 private:
    // the result that every generic op reports
    static constexpr bool kNothingDone = false;

 public:
    // left & right
    static inline bool
    op_and(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // left & rights[0] & ... & rights[n_rights - 1]
    static inline bool
    op_and_multiple(ElementT* const left,
                    const ElementT* const* const rights,
                    const size_t start_left,
                    const size_t* const __restrict start_rights,
                    const size_t n_rights,
                    const size_t size) {
        return kNothingDone;
    }

    // left | right
    static inline bool
    op_or(ElementT* const left,
          const ElementT* const right,
          const size_t start_left,
          const size_t start_right,
          const size_t size) {
        return kNothingDone;
    }

    // left | rights[0] | ... | rights[n_rights - 1]
    static inline bool
    op_or_multiple(ElementT* const left,
                   const ElementT* const* const rights,
                   const size_t start_left,
                   const size_t* const __restrict start_rights,
                   const size_t n_rights,
                   const size_t size) {
        return kNothingDone;
    }

    // left ^ right
    static inline bool
    op_xor(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // the `sub` op
    static inline bool
    op_sub(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }
};
// Declares a full specialization of ForwardOpsImpl for ELEMENTTYPE.
// Only the signatures of the six static functions are declared here,
// no bodies; the definitions have to be provided out of line.
// NOTE: do not put `//` comments inside the macro body, because the
// line continuations would make them swallow the following lines.
#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \
template <> \
struct ForwardOpsImpl<ELEMENTTYPE> { \
static bool \
op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_and_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_or_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
};
// Emit one specialization per type listed in ALL_FORWARD_TYPES_1.
ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS)
// The helper macro is not needed past this point.
#undef DECLARE_PARTIAL_FORWARD_OPS
///////////////////////////////////////////////////////////////////////////
#undef ALL_DATATYPES_1 #undef ALL_DATATYPES_1
#undef ALL_FORWARD_TYPES_1
} // namespace sve } // namespace sve
} // namespace arm } // namespace arm

View File

@ -28,8 +28,7 @@
#include "sve-decl.h" #include "sve-decl.h"
#include "bitset/common.h" #include "bitset/common.h"
#include "bitset/detail/element_wise.h"
// #include <stdio.h>
namespace milvus { namespace milvus {
namespace bitset { namespace bitset {
@ -1623,6 +1622,151 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
} }
} }
///////////////////////////////////////////////////////////////////////////
// forward ops
//
// Definitions of the forward ops for uint8_t elements.
// Every function hands its arguments over, unchanged, to the
// same-named function of ElementWiseBitsetPolicy<uint8_t>
// and then reports `true`.

bool
ForwardOpsImpl<uint8_t>::op_and(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_and_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_or(uint8_t* const left,
                               const uint8_t* const right,
                               const size_t start_left,
                               const size_t start_right,
                               const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_or_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_xor(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint8_t>::op_sub(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
//
// Definitions of the forward ops for uint64_t elements.
// Every function hands its arguments over, unchanged, to the
// same-named function of ElementWiseBitsetPolicy<uint64_t>
// and then reports `true`.

bool
ForwardOpsImpl<uint64_t>::op_and(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_and_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_or(uint64_t* const left,
                                const uint64_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_or_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_xor(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}

bool
ForwardOpsImpl<uint64_t>::op_sub(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
} // namespace sve } // namespace sve

View File

@ -55,6 +55,30 @@ struct VectorizedSve {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static constexpr inline auto op_arith_compare = static constexpr inline auto op_arith_compare =
sve::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare; sve::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare;
// Function pointers to the forward ops of sve::ForwardOpsImpl,
// one variable template per operation.
template <typename ElementT>
static constexpr inline auto forward_op_and =
    &sve::ForwardOpsImpl<ElementT>::op_and;

template <typename ElementT>
static constexpr inline auto forward_op_and_multiple =
    &sve::ForwardOpsImpl<ElementT>::op_and_multiple;

template <typename ElementT>
static constexpr inline auto forward_op_or =
    &sve::ForwardOpsImpl<ElementT>::op_or;

template <typename ElementT>
static constexpr inline auto forward_op_or_multiple =
    &sve::ForwardOpsImpl<ElementT>::op_or_multiple;

template <typename ElementT>
static constexpr inline auto forward_op_xor =
    &sve::ForwardOpsImpl<ElementT>::op_xor;

template <typename ElementT>
static constexpr inline auto forward_op_sub =
    &sve::ForwardOpsImpl<ElementT>::op_sub;
}; };
} // namespace arm } // namespace arm

View File

@ -88,6 +88,11 @@ using namespace milvus::bitset::detail::arm;
FUNC(__VA_ARGS__, Mod, LT); \ FUNC(__VA_ARGS__, Mod, LT); \
FUNC(__VA_ARGS__, Mod, NE); FUNC(__VA_ARGS__, Mod, NE);
// A facility to run through all possible forward ElementT.
// Expands to `FUNC(uint8_t); FUNC(uint64_t);`. The type name is also
// token-pasted into identifiers by the macros passed as FUNC
// (e.g. forward_op_and_##ELEMENTTYPE), so it has to be a single token.
#define ALL_FORWARD_OPS(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
// //
namespace milvus { namespace milvus {
namespace bitset { namespace bitset {
@ -235,6 +240,7 @@ ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, float)
ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, double) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, double)
#undef DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL #undef DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL
} // namespace dynamic } // namespace dynamic
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -282,6 +288,8 @@ ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int64_t)
ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, float) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, float)
ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, double) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, double)
#undef DISPATCH_OP_WITHIN_RANGE_VAL_IMPL
} // namespace dynamic } // namespace dynamic
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -332,6 +340,108 @@ ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int64_t)
ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, float) ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, float)
ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, double) ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, double)
#undef DISPATCH_OP_ARITH_COMPARE
} // namespace dynamic
/////////////////////////////////////////////////////////////////////////////
// forward_ops
// Pointer to a forward op that takes a single right-hand operand;
// this is the type of the and / or / xor / sub function pointers.
// Top-level `const` on the parameters is not part of a function type,
// so it is not spelled out here.
template <typename ElementT>
using ForwardOpsOp2 = bool (*)(ElementT* left,
                               const ElementT* right,
                               size_t start_left,
                               size_t start_right,
                               size_t size);
// Pointer to a forward op that takes `n_rights` right-hand operands,
// each with its own start offset; this is the type of the
// and_multiple / or_multiple function pointers.
// Top-level `const` on the parameters is not part of a function type,
// so it is not spelled out here.
template <typename ElementT>
using ForwardOpsOpMultiple2 = bool (*)(ElementT* left,
                                       const ElementT* const* rights,
                                       size_t start_left,
                                       const size_t* __restrict start_rights,
                                       size_t n_rights,
                                       size_t size);
// Defines, for the given element type, one function pointer variable
// per forward op (forward_op_and_<type>, forward_op_and_multiple_<type>,
// ...). Every pointer initially refers to the VectorizedRef version;
// the SET_FORWARD_OPS_* macros in init_dynamic_hook() reassign them.
#define DECLARE_FORWARD_OPS_OP2(ELEMENTTYPE) \
ForwardOpsOp2<ELEMENTTYPE> forward_op_and_##ELEMENTTYPE = \
VectorizedRef::template forward_op_and<ELEMENTTYPE>; \
ForwardOpsOpMultiple2<ELEMENTTYPE> forward_op_and_multiple_##ELEMENTTYPE = \
VectorizedRef::template forward_op_and_multiple<ELEMENTTYPE>; \
ForwardOpsOp2<ELEMENTTYPE> forward_op_or_##ELEMENTTYPE = \
VectorizedRef::template forward_op_or<ELEMENTTYPE>; \
ForwardOpsOpMultiple2<ELEMENTTYPE> forward_op_or_multiple_##ELEMENTTYPE = \
VectorizedRef::template forward_op_or_multiple<ELEMENTTYPE>; \
ForwardOpsOp2<ELEMENTTYPE> forward_op_xor_##ELEMENTTYPE = \
VectorizedRef::template forward_op_xor<ELEMENTTYPE>; \
ForwardOpsOp2<ELEMENTTYPE> forward_op_sub_##ELEMENTTYPE = \
VectorizedRef::template forward_op_sub<ELEMENTTYPE>;
// Define the pointers for every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(DECLARE_FORWARD_OPS_OP2)
#undef DECLARE_FORWARD_OPS_OP2
//
namespace dynamic {
// Defines all six member functions of ForwardOpsImpl<ELEMENTTYPE>
// (despite the `_OP_AND` suffix of the macro name, this covers
// and, and_multiple, or, or_multiple, xor and sub).
// Every function forwards its arguments to the matching
// forward_op_*_<type> function pointer and returns its result.
#define DISPATCH_FORWARD_OPS_OP_AND(ELEMENTTYPE) \
bool ForwardOpsImpl<ELEMENTTYPE>::op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size) { \
return forward_op_and_##ELEMENTTYPE( \
left, right, start_left, start_right, size); \
} \
bool ForwardOpsImpl<ELEMENTTYPE>::op_and_multiple( \
ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size) { \
return forward_op_and_multiple_##ELEMENTTYPE( \
left, rights, start_left, start_rights, n_rights, size); \
} \
bool ForwardOpsImpl<ELEMENTTYPE>::op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size) { \
return forward_op_or_##ELEMENTTYPE( \
left, right, start_left, start_right, size); \
} \
bool ForwardOpsImpl<ELEMENTTYPE>::op_or_multiple( \
ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size) { \
return forward_op_or_multiple_##ELEMENTTYPE( \
left, rights, start_left, start_rights, n_rights, size); \
} \
bool ForwardOpsImpl<ELEMENTTYPE>::op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size) { \
return forward_op_xor_##ELEMENTTYPE( \
left, right, start_left, start_right, size); \
} \
bool ForwardOpsImpl<ELEMENTTYPE>::op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size) { \
return forward_op_sub_##ELEMENTTYPE( \
left, right, start_left, start_right, size); \
}
// Emit the definitions for every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(DISPATCH_FORWARD_OPS_OP_AND)
#undef DISPATCH_FORWARD_OPS_OP_AND
} // namespace dynamic } // namespace dynamic
} // namespace detail } // namespace detail
@ -402,11 +512,28 @@ init_dynamic_hook() {
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, float)
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, double) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, double)
// Repoints all six forward-op function pointers of the given element
// type to the VectorizedAvx512 versions.
#define SET_FORWARD_OPS_AVX512(ELEMENTTYPE) \
forward_op_and_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_and<ELEMENTTYPE>; \
forward_op_and_multiple_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_and_multiple<ELEMENTTYPE>; \
forward_op_or_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_or<ELEMENTTYPE>; \
forward_op_or_multiple_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_or_multiple<ELEMENTTYPE>; \
forward_op_xor_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_xor<ELEMENTTYPE>; \
forward_op_sub_##ELEMENTTYPE = \
VectorizedAvx512::template forward_op_sub<ELEMENTTYPE>;
// Apply it to every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(SET_FORWARD_OPS_AVX512)
#undef SET_OP_COMPARE_COLUMN_AVX512 #undef SET_OP_COMPARE_COLUMN_AVX512
#undef SET_OP_COMPARE_VAL_AVX512 #undef SET_OP_COMPARE_VAL_AVX512
#undef SET_OP_WITHIN_RANGE_COLUMN_AVX512 #undef SET_OP_WITHIN_RANGE_COLUMN_AVX512
#undef SET_OP_WITHIN_RANGE_VAL_AVX512 #undef SET_OP_WITHIN_RANGE_VAL_AVX512
#undef SET_ARITH_COMPARE_AVX512 #undef SET_ARITH_COMPARE_AVX512
#undef SET_FORWARD_OPS_AVX512
return; return;
} }
@ -467,11 +594,28 @@ init_dynamic_hook() {
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, float)
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, double) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, double)
// Repoints all six forward-op function pointers of the given element
// type to the VectorizedAvx2 versions.
#define SET_FORWARD_OPS_AVX2(ELEMENTTYPE) \
forward_op_and_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_and<ELEMENTTYPE>; \
forward_op_and_multiple_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_and_multiple<ELEMENTTYPE>; \
forward_op_or_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_or<ELEMENTTYPE>; \
forward_op_or_multiple_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_or_multiple<ELEMENTTYPE>; \
forward_op_xor_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_xor<ELEMENTTYPE>; \
forward_op_sub_##ELEMENTTYPE = \
VectorizedAvx2::template forward_op_sub<ELEMENTTYPE>;
// Apply it to every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(SET_FORWARD_OPS_AVX2)
#undef SET_OP_COMPARE_COLUMN_AVX2 #undef SET_OP_COMPARE_COLUMN_AVX2
#undef SET_OP_COMPARE_VAL_AVX2 #undef SET_OP_COMPARE_VAL_AVX2
#undef SET_OP_WITHIN_RANGE_COLUMN_AVX2 #undef SET_OP_WITHIN_RANGE_COLUMN_AVX2
#undef SET_OP_WITHIN_RANGE_VAL_AVX2 #undef SET_OP_WITHIN_RANGE_VAL_AVX2
#undef SET_ARITH_COMPARE_AVX2 #undef SET_ARITH_COMPARE_AVX2
#undef SET_FORWARD_OPS_AVX2
return; return;
} }
@ -535,15 +679,33 @@ init_dynamic_hook() {
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, float)
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, double) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, double)
// Repoints all six forward-op function pointers of the given element
// type to the VectorizedSve versions.
#define SET_FORWARD_OPS_SVE(ELEMENTTYPE) \
forward_op_and_##ELEMENTTYPE = \
VectorizedSve::template forward_op_and<ELEMENTTYPE>; \
forward_op_and_multiple_##ELEMENTTYPE = \
VectorizedSve::template forward_op_and_multiple<ELEMENTTYPE>; \
forward_op_or_##ELEMENTTYPE = \
VectorizedSve::template forward_op_or<ELEMENTTYPE>; \
forward_op_or_multiple_##ELEMENTTYPE = \
VectorizedSve::template forward_op_or_multiple<ELEMENTTYPE>; \
forward_op_xor_##ELEMENTTYPE = \
VectorizedSve::template forward_op_xor<ELEMENTTYPE>; \
forward_op_sub_##ELEMENTTYPE = \
VectorizedSve::template forward_op_sub<ELEMENTTYPE>;
// Apply it to every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(SET_FORWARD_OPS_SVE)
#undef SET_OP_COMPARE_COLUMN_SVE #undef SET_OP_COMPARE_COLUMN_SVE
#undef SET_OP_COMPARE_VAL_SVE #undef SET_OP_COMPARE_VAL_SVE
#undef SET_OP_WITHIN_RANGE_COLUMN_SVE #undef SET_OP_WITHIN_RANGE_COLUMN_SVE
#undef SET_OP_WITHIN_RANGE_VAL_SVE #undef SET_OP_WITHIN_RANGE_VAL_SVE
#undef SET_ARITH_COMPARE_SVE #undef SET_ARITH_COMPARE_SVE
#undef SET_FORWARD_OPS_SVE
return; return;
} }
#endif #endif
// neon ? // neon ?
{ {
#define SET_OP_COMPARE_COLUMN_NEON(TTYPE, UTYPE, OP) \ #define SET_OP_COMPARE_COLUMN_NEON(TTYPE, UTYPE, OP) \
@ -600,11 +762,28 @@ init_dynamic_hook() {
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, float)
ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, double) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, double)
// Repoints all six forward-op function pointers of the given element
// type to the VectorizedNeon versions.
#define SET_FORWARD_OPS_NEON(ELEMENTTYPE) \
forward_op_and_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_and<ELEMENTTYPE>; \
forward_op_and_multiple_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_and_multiple<ELEMENTTYPE>; \
forward_op_or_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_or<ELEMENTTYPE>; \
forward_op_or_multiple_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_or_multiple<ELEMENTTYPE>; \
forward_op_xor_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_xor<ELEMENTTYPE>; \
forward_op_sub_##ELEMENTTYPE = \
VectorizedNeon::template forward_op_sub<ELEMENTTYPE>;
// Apply it to every type listed in ALL_FORWARD_OPS.
ALL_FORWARD_OPS(SET_FORWARD_OPS_NEON)
#undef SET_OP_COMPARE_COLUMN_NEON #undef SET_OP_COMPARE_COLUMN_NEON
#undef SET_OP_COMPARE_VAL_NEON #undef SET_OP_COMPARE_VAL_NEON
#undef SET_OP_WITHIN_RANGE_COLUMN_NEON #undef SET_OP_WITHIN_RANGE_COLUMN_NEON
#undef SET_OP_WITHIN_RANGE_VAL_NEON #undef SET_OP_WITHIN_RANGE_VAL_NEON
#undef SET_ARITH_COMPARE_NEON #undef SET_ARITH_COMPARE_NEON
#undef SET_FORWARD_OPS_NEON
return; return;
} }
@ -616,6 +795,7 @@ init_dynamic_hook() {
#undef ALL_COMPARE_OPS #undef ALL_COMPARE_OPS
#undef ALL_RANGE_OPS #undef ALL_RANGE_OPS
#undef ALL_ARITH_CMP_OPS #undef ALL_ARITH_CMP_OPS
#undef ALL_FORWARD_OPS
// //
static int init_dynamic_ = []() { static int init_dynamic_ = []() {

View File

@ -37,6 +37,11 @@ namespace dynamic {
FUNC(float); \ FUNC(float); \
FUNC(double); FUNC(double);
// A facility to run through all acceptable forward element types.
// Expands to `FUNC(uint8_t); FUNC(uint64_t);`, i.e. FUNC is applied
// once per type and every application is followed by a semicolon.
#define ALL_FORWARD_TYPES_1(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// the default implementation // the default implementation
template <typename T, typename U, CompareOpType Op> template <typename T, typename U, CompareOpType Op>
@ -176,11 +181,125 @@ struct OpArithCompareImpl {
ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
// #undef DECLARE_PARTIAL_OP_ARITH_COMPARE
///////////////////////////////////////////////////////////////////////////
// the default implementation
// Generic version of the forward ops.
// None of the functions touches its arguments; each one simply
// returns `false`.
template <typename ElementT>
struct ForwardOpsImpl {
 private:
    // the result that every generic op reports
    static constexpr bool kNothingDone = false;

 public:
    // left & right
    static inline bool
    op_and(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // left & rights[0] & ... & rights[n_rights - 1]
    static inline bool
    op_and_multiple(ElementT* const left,
                    const ElementT* const* const rights,
                    const size_t start_left,
                    const size_t* const __restrict start_rights,
                    const size_t n_rights,
                    const size_t size) {
        return kNothingDone;
    }

    // left | right
    static inline bool
    op_or(ElementT* const left,
          const ElementT* const right,
          const size_t start_left,
          const size_t start_right,
          const size_t size) {
        return kNothingDone;
    }

    // left | rights[0] | ... | rights[n_rights - 1]
    static inline bool
    op_or_multiple(ElementT* const left,
                   const ElementT* const* const rights,
                   const size_t start_left,
                   const size_t* const __restrict start_rights,
                   const size_t n_rights,
                   const size_t size) {
        return kNothingDone;
    }

    // left ^ right
    static inline bool
    op_xor(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }

    // the `sub` op
    static inline bool
    op_sub(ElementT* const left,
           const ElementT* const right,
           const size_t start_left,
           const size_t start_right,
           const size_t size) {
        return kNothingDone;
    }
};
// Declares a full specialization of ForwardOpsImpl for ELEMENTTYPE.
// Only the signatures of the six static functions are declared here,
// no bodies; the definitions have to be provided out of line.
// NOTE: do not put `//` comments inside the macro body, because the
// line continuations would make them swallow the following lines.
#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \
template <> \
struct ForwardOpsImpl<ELEMENTTYPE> { \
static bool \
op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_and_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_or_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
};
// Emit one specialization per type listed in ALL_FORWARD_TYPES_1.
ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS)
// The helper macro is not needed past this point.
#undef DECLARE_PARTIAL_FORWARD_OPS
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
#undef ALL_DATATYPES_1 #undef ALL_DATATYPES_1
#undef ALL_FORWARD_TYPES_1
} // namespace dynamic } // namespace dynamic
@ -248,6 +367,77 @@ struct VectorizedDynamic {
return dynamic::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare( return dynamic::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare(
bitmask, src, right_operand, value, size); bitmask, src, right_operand, value, size);
} }
// The following functions just forward parameters to the reference code,
// generated for a particular platform.
template <typename ElementT>
static inline bool
forward_op_and(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_and(
left, right, start_left, start_right, size);
}
template <typename ElementT>
static inline bool
forward_op_and_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_and_multiple(
left, rights, start_left, start_rights, n_rights, size);
}
template <typename ElementT>
static inline bool
forward_op_or(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_or(
left, right, start_left, start_right, size);
}
template <typename ElementT>
static inline bool
forward_op_or_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_or_multiple(
left, rights, start_left, start_rights, n_rights, size);
}
template <typename ElementT>
static inline bool
forward_op_xor(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_xor(
left, right, start_left, start_right, size);
}
template <typename ElementT>
static inline bool
forward_op_sub(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return dynamic::ForwardOpsImpl<ElementT>::op_sub(
left, right, start_left, start_right, size);
}
}; };
} // namespace detail } // namespace detail

View File

@ -27,9 +27,13 @@ namespace bitset {
namespace detail { namespace detail {
// The default reference vectorizer. // The default reference vectorizer.
// Its every function returns a boolean value whether a vectorized implementation // Functions return a boolean value whether a vectorized implementation
// exists and was invoked. If not, then the caller code will use a default // exists and was invoked. If not, then the caller code will use a default
// non-vectorized implementation. // non-vectorized implementation.
// Certain functions just forward the parameters to the platform code. Sometimes
// the compiler can do a good job on its own; we just need to make sure that it
// uses the appropriate hardware instructions that are available. No specialized
// implementation is used under the hood.
// The default vectorizer provides no vectorized implementation, forcing the // The default vectorizer provides no vectorized implementation, forcing the
// caller to use a default non-vectorized implementation every time. // caller to use a default non-vectorized implementation every time.
struct VectorizedRef { struct VectorizedRef {
@ -88,6 +92,72 @@ struct VectorizedRef {
const size_t size) { const size_t size) {
return false; return false;
} }
// The following functions just forward parameters to the reference code,
// generated for a particular platform.
// The reference 'platform' is just a default platform.
// Reference stub for the two-operand `and`: no argument is used and false is
// returned, i.e. no vectorized implementation was invoked and the caller is
// expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_and(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// Reference stub for the multi-operand `and`: no argument is used and false
// is returned, i.e. no vectorized implementation was invoked and the caller
// is expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_and_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// Reference stub for the two-operand `or`: no argument is used and false is
// returned, i.e. no vectorized implementation was invoked and the caller is
// expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_or(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// Reference stub for the multi-operand `or`: no argument is used and false
// is returned, i.e. no vectorized implementation was invoked and the caller
// is expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_or_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// Reference stub for the two-operand `xor`: no argument is used and false is
// returned, i.e. no vectorized implementation was invoked and the caller is
// expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_xor(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// Reference stub for the two-operand `sub`: no argument is used and false is
// returned, i.e. no vectorized implementation was invoked and the caller is
// expected to run its default non-vectorized code.
template <typename ElementT>
static inline bool
forward_op_sub(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
}; };
} // namespace detail } // namespace detail

View File

@ -39,6 +39,11 @@ namespace avx2 {
FUNC(float); \ FUNC(float); \
FUNC(double); FUNC(double);
// A facility to run through all acceptable forward types: expands FUNC once
// for uint8_t and once for uint64_t, each expansion followed by a semicolon.
#define ALL_FORWARD_TYPES_1(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// the default implementation does nothing // the default implementation does nothing
@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// forward ops
// Primary ForwardOpsImpl template of this namespace. Every member ignores its
// arguments and returns false. The element types listed in
// ALL_FORWARD_TYPES_1 receive their own specialization through
// DECLARE_PARTIAL_FORWARD_OPS.
template <typename ElementT>
struct ForwardOpsImpl {
// two-operand `and`: nothing is done for a generic ElementT
static inline bool
op_and(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// multi-operand `and`: nothing is done for a generic ElementT
static inline bool
op_and_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// two-operand `or`: nothing is done for a generic ElementT
static inline bool
op_or(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// multi-operand `or`: nothing is done for a generic ElementT
static inline bool
op_or_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// two-operand `xor`: nothing is done for a generic ElementT
static inline bool
op_xor(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// two-operand `sub`: nothing is done for a generic ElementT
static inline bool
op_sub(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
};
// Declares a full specialization of ForwardOpsImpl for ELEMENTTYPE.
// Only the six static member functions are declared; the macro supplies no
// bodies, so the definitions have to be provided out of line.
// The macro is applied to every type listed in ALL_FORWARD_TYPES_1 (uint8_t
// and uint64_t) and is undefined immediately afterwards.
// Do not put `//` comments inside the macro body: the line continuations
// would pull the following lines into the comment.
#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \
template <> \
struct ForwardOpsImpl<ELEMENTTYPE> { \
static bool \
op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_and_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_or_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
};
ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS)
#undef DECLARE_PARTIAL_FORWARD_OPS
///////////////////////////////////////////////////////////////////////////
#undef ALL_DATATYPES_1 #undef ALL_DATATYPES_1
#undef ALL_FORWARD_TYPES_1
} // namespace avx2 } // namespace avx2
} // namespace x86 } // namespace x86

View File

@ -28,6 +28,7 @@
#include "avx2-decl.h" #include "avx2-decl.h"
#include "bitset/common.h" #include "bitset/common.h"
#include "bitset/detail/element_wise.h"
#include "common.h" #include "common.h"
namespace milvus { namespace milvus {
@ -1649,6 +1650,151 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
} }
} }
///////////////////////////////////////////////////////////////////////////
// forward ops
//
// `and` of two operands for uint8_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_and;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_and(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}
// `and` with `n_rights` right-hand operands for uint8_t elements in the avx2
// namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint8_t>::op_and_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_and_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `or` of two operands for uint8_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_or;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_or(uint8_t* const left,
                               const uint8_t* const right,
                               const size_t start_left,
                               const size_t start_right,
                               const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}
// `or` with `n_rights` right-hand operands for uint8_t elements in the avx2
// namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint8_t>::op_or_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_or_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `xor` of two operands for uint8_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_xor;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_xor(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}
// `sub` of two operands for uint8_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_sub;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_sub(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
//
// `and` of two operands for uint64_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_and;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_and(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}
// `and` with `n_rights` right-hand operands for uint64_t elements in the avx2
// namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint64_t>::op_and_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_and_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `or` of two operands for uint64_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_or;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_or(uint64_t* const left,
                                const uint64_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}
// `or` with `n_rights` right-hand operands for uint64_t elements in the avx2
// namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint64_t>::op_or_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_or_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `xor` of two operands for uint64_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_xor;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_xor(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}
// `sub` of two operands for uint64_t elements in the avx2 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_sub;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_sub(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
} // namespace avx2 } // namespace avx2

View File

@ -55,6 +55,30 @@ struct VectorizedAvx2 {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static constexpr inline auto op_arith_compare = static constexpr inline auto op_arith_compare =
avx2::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare; avx2::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_and, published under the
// `forward_op_and` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_and =
    &avx2::ForwardOpsImpl<ElementT>::op_and;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_and_multiple, published under
// the `forward_op_and_multiple` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_and_multiple =
    &avx2::ForwardOpsImpl<ElementT>::op_and_multiple;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_or, published under the
// `forward_op_or` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_or =
    &avx2::ForwardOpsImpl<ElementT>::op_or;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_or_multiple, published under
// the `forward_op_or_multiple` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_or_multiple =
    &avx2::ForwardOpsImpl<ElementT>::op_or_multiple;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_xor, published under the
// `forward_op_xor` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_xor =
    &avx2::ForwardOpsImpl<ElementT>::op_xor;
// Pointer to avx2::ForwardOpsImpl<ElementT>::op_sub, published under the
// `forward_op_sub` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_sub =
    &avx2::ForwardOpsImpl<ElementT>::op_sub;
}; };
} // namespace x86 } // namespace x86

View File

@ -39,6 +39,11 @@ namespace avx512 {
FUNC(float); \ FUNC(float); \
FUNC(double); FUNC(double);
// A facility to run through all acceptable forward types: expands FUNC once
// for uint8_t and once for uint64_t, each expansion followed by a semicolon.
#define ALL_FORWARD_TYPES_1(FUNC) \
FUNC(uint8_t); \
FUNC(uint64_t);
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// the default implementation does nothing // the default implementation does nothing
@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// forward ops
// Primary ForwardOpsImpl template of this namespace. Every member ignores its
// arguments and returns false. The element types listed in
// ALL_FORWARD_TYPES_1 receive their own specialization through
// DECLARE_PARTIAL_FORWARD_OPS.
template <typename ElementT>
struct ForwardOpsImpl {
// two-operand `and`: nothing is done for a generic ElementT
static inline bool
op_and(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// multi-operand `and`: nothing is done for a generic ElementT
static inline bool
op_and_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// two-operand `or`: nothing is done for a generic ElementT
static inline bool
op_or(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// multi-operand `or`: nothing is done for a generic ElementT
static inline bool
op_or_multiple(ElementT* const left,
const ElementT* const* const rights,
const size_t start_left,
const size_t* const __restrict start_rights,
const size_t n_rights,
const size_t size) {
return false;
}
// two-operand `xor`: nothing is done for a generic ElementT
static inline bool
op_xor(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
// two-operand `sub`: nothing is done for a generic ElementT
static inline bool
op_sub(ElementT* const left,
const ElementT* const right,
const size_t start_left,
const size_t start_right,
const size_t size) {
return false;
}
};
// Declares a full specialization of ForwardOpsImpl for ELEMENTTYPE.
// Only the six static member functions are declared; the macro supplies no
// bodies, so the definitions have to be provided out of line.
// The macro is applied to every type listed in ALL_FORWARD_TYPES_1 (uint8_t
// and uint64_t) and is undefined immediately afterwards.
// Do not put `//` comments inside the macro body: the line continuations
// would pull the following lines into the comment.
#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \
template <> \
struct ForwardOpsImpl<ELEMENTTYPE> { \
static bool \
op_and(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_and_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_or(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_or_multiple(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const* const rights, \
const size_t start_left, \
const size_t* const __restrict start_rights, \
const size_t n_rights, \
const size_t size); \
\
static bool \
op_sub(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
\
static bool \
op_xor(ELEMENTTYPE* const left, \
const ELEMENTTYPE* const right, \
const size_t start_left, \
const size_t start_right, \
const size_t size); \
};
ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS)
#undef DECLARE_PARTIAL_FORWARD_OPS
///////////////////////////////////////////////////////////////////////////
#undef ALL_DATATYPES_1 #undef ALL_DATATYPES_1
#undef ALL_FORWARD_TYPES_1
} // namespace avx512 } // namespace avx512
} // namespace x86 } // namespace x86

View File

@ -28,6 +28,7 @@
#include "avx512-decl.h" #include "avx512-decl.h"
#include "bitset/common.h" #include "bitset/common.h"
#include "bitset/detail/element_wise.h"
#include "common.h" #include "common.h"
namespace milvus { namespace milvus {
@ -1871,6 +1872,151 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
} }
} }
///////////////////////////////////////////////////////////////////////////
// forward ops
//
// `and` of two operands for uint8_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_and;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_and(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}
// `and` with `n_rights` right-hand operands for uint8_t elements in the
// avx512 namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint8_t>::op_and_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_and_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `or` of two operands for uint8_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_or;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_or(uint8_t* const left,
                               const uint8_t* const right,
                               const size_t start_left,
                               const size_t start_right,
                               const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}
// `or` with `n_rights` right-hand operands for uint8_t elements in the
// avx512 namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint8_t>::op_or_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_or_multiple(
    uint8_t* const left,
    const uint8_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `xor` of two operands for uint8_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_xor;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_xor(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}
// `sub` of two operands for uint8_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint8_t>::op_sub;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint8_t>::op_sub(uint8_t* const left,
                                const uint8_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint8_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
//
// `and` of two operands for uint64_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_and;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_and(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and(left, right, start_left, start_right, size);
    return true;
}
// `and` with `n_rights` right-hand operands for uint64_t elements in the
// avx512 namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint64_t>::op_and_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_and_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_and_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `or` of two operands for uint64_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_or;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_or(uint64_t* const left,
                                const uint64_t* const right,
                                const size_t start_left,
                                const size_t start_right,
                                const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or(left, right, start_left, start_right, size);
    return true;
}
// `or` with `n_rights` right-hand operands for uint64_t elements in the
// avx512 namespace. The work itself is delegated to
// ElementWiseBitsetPolicy<uint64_t>::op_or_multiple; true is returned
// unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_or_multiple(
    uint64_t* const left,
    const uint64_t* const* const rights,
    const size_t start_left,
    const size_t* const __restrict start_rights,
    const size_t n_rights,
    const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_or_multiple(
        left, rights, start_left, start_rights, n_rights, size);
    return true;
}
// `xor` of two operands for uint64_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_xor;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_xor(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_xor(left, right, start_left, start_right, size);
    return true;
}
// `sub` of two operands for uint64_t elements in the avx512 namespace.
// The work itself is delegated to ElementWiseBitsetPolicy<uint64_t>::op_sub;
// true is returned unconditionally once that call has completed.
bool
ForwardOpsImpl<uint64_t>::op_sub(uint64_t* const left,
                                 const uint64_t* const right,
                                 const size_t start_left,
                                 const size_t start_right,
                                 const size_t size) {
    using policy_type = ElementWiseBitsetPolicy<uint64_t>;

    policy_type::op_sub(left, right, start_left, start_right, size);
    return true;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
} // namespace avx512 } // namespace avx512

View File

@ -55,6 +55,30 @@ struct VectorizedAvx512 {
template <typename T, ArithOpType AOp, CompareOpType CmpOp> template <typename T, ArithOpType AOp, CompareOpType CmpOp>
static constexpr inline auto op_arith_compare = static constexpr inline auto op_arith_compare =
avx512::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare; avx512::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_and, published under the
// `forward_op_and` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_and =
    &avx512::ForwardOpsImpl<ElementT>::op_and;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_and_multiple, published
// under the `forward_op_and_multiple` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_and_multiple =
    &avx512::ForwardOpsImpl<ElementT>::op_and_multiple;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_or, published under the
// `forward_op_or` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_or =
    &avx512::ForwardOpsImpl<ElementT>::op_or;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_or_multiple, published
// under the `forward_op_or_multiple` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_or_multiple =
    &avx512::ForwardOpsImpl<ElementT>::op_or_multiple;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_xor, published under the
// `forward_op_xor` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_xor =
    &avx512::ForwardOpsImpl<ElementT>::op_xor;
// Pointer to avx512::ForwardOpsImpl<ElementT>::op_sub, published under the
// `forward_op_sub` name of this vectorizer.
template <typename ElementT>
static constexpr inline auto forward_op_sub =
    &avx512::ForwardOpsImpl<ElementT>::op_sub;
}; };
} // namespace x86 } // namespace x86

View File

@ -23,14 +23,13 @@ namespace detail {
template <typename PolicyT> template <typename PolicyT>
struct ConstProxy { struct ConstProxy {
using policy_type = PolicyT; using policy_type = PolicyT;
using size_type = typename policy_type::size_type;
using data_type = typename policy_type::data_type; using data_type = typename policy_type::data_type;
using self_type = ConstProxy; using self_type = ConstProxy;
const data_type& element; const data_type& element;
data_type mask; data_type mask;
inline ConstProxy(const data_type& _element, const size_type _shift) inline ConstProxy(const data_type& _element, const size_t _shift)
: element{_element} { : element{_element} {
mask = (data_type(1) << _shift); mask = (data_type(1) << _shift);
} }
@ -47,15 +46,13 @@ struct ConstProxy {
template <typename PolicyT> template <typename PolicyT>
struct Proxy { struct Proxy {
using policy_type = PolicyT; using policy_type = PolicyT;
using size_type = typename policy_type::size_type;
using data_type = typename policy_type::data_type; using data_type = typename policy_type::data_type;
using self_type = Proxy; using self_type = Proxy;
data_type& element; data_type& element;
data_type mask; data_type mask;
inline Proxy(data_type& _element, const size_type _shift) inline Proxy(data_type& _element, const size_t _shift) : element{_element} {
: element{_element} {
mask = (data_type(1) << _shift); mask = (data_type(1) << _shift);
} }

File diff suppressed because it is too large Load Diff