Support dynamic chunk size

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
This commit is contained in:
FluorineDog 2020-12-23 19:02:37 +08:00 committed by yefu.chen
parent 7ce0f27ebc
commit 342f4cb741
21 changed files with 160 additions and 202 deletions

View File

@ -71,7 +71,7 @@ BinarySearchBruteForceFast(MetricType metric_type,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
const idx_t block_size = segcore::DefaultElementPerChunk;
const idx_t block_size = chunk_size;
bool use_heap = true;
if (metric_type == faiss::METRIC_Jaccard || metric_type == faiss::METRIC_Tanimoto) {

View File

@ -19,7 +19,6 @@
#include "query/BruteForceSearch.h"
namespace milvus::query {
using segcore::DefaultElementPerChunk;
static faiss::ConcurrentBitsetPtr
create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk_id) {
@ -48,7 +47,6 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto& record = segment.get_insert_record();
// step 1: binary search to find the barrier of the snapshot
auto ins_barrier = get_barrier(record, timestamp);
auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk);
// auto del_barrier = get_barrier(deleted_record_, timestamp);
#if 0
@ -89,29 +87,40 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
for (int64_t i = 0; i < total_count; ++i) {
auto& x = uids[i];
if (x != -1) {
x += chunk_id * DefaultElementPerChunk;
x += chunk_id * indexing_entry.get_chunk_size();
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids);
}
using segcore::FloatVector;
auto vec_ptr = record.get_entity<FloatVector>(vecfield_offset);
// step 4: brute force search where small indexing is unavailable
auto vec_chunk_size = vec_ptr->get_chunk_size();
Assert(vec_chunk_size == indexing_entry.get_chunk_size());
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto& chunk = vec_ptr->get_chunk(chunk_id);
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto element_begin = chunk_id * vec_chunk_size;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size);
auto nsize = element_end - element_begin;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view);
Assert(buf_uids.size() == total_count);
// convert chunk uid to segment uid
for (auto& x : buf_uids) {
if (x != -1) {
x += chunk_id * DefaultElementPerChunk;
x += chunk_id * vec_chunk_size;
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
@ -148,7 +157,6 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto& record = segment.get_insert_record();
// step 1: binary search to find the barrier of the snapshot
auto ins_barrier = get_barrier(record, timestamp);
auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk);
auto metric_type = GetMetricType(info.metric_type_);
// auto del_barrier = get_barrier(deleted_record_, timestamp);
@ -181,13 +189,17 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto max_indexed_id = 0;
// step 4: brute force search where small indexing is unavailable
auto vec_chunk_size = vec_ptr->get_chunk_size();
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
auto& chunk = vec_ptr->get_chunk(chunk_id);
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto element_begin = chunk_id * vec_chunk_size;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size);
auto nsize = element_end - element_begin;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
BinarySearchBruteForce(query_dataset, chunk.data(), nsize, buf_dis.data(), buf_uids.data(), bitmap_view);
@ -195,7 +207,7 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
// convert chunk uid to segment uid
for (auto& x : buf_uids) {
if (x != -1) {
x += chunk_id * DefaultElementPerChunk;
x += chunk_id * vec_chunk_size;
}
}

View File

@ -132,23 +132,24 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_fu
RetType results(vec.num_chunk());
auto indexing_barrier = indexing_record.get_finished_ack();
auto chunk_size = vec.get_chunk_size();
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
auto& result = results[chunk_id];
auto indexing = entry.get_indexing(chunk_id);
auto data = index_func(indexing);
result = std::move(*data);
Assert(result.size() == segcore::DefaultElementPerChunk);
Assert(result.size() == chunk_size);
}
for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) {
auto& result = results[chunk_id];
result.resize(segcore::DefaultElementPerChunk);
result.resize(chunk_size);
auto chunk = vec.get_chunk(chunk_id);
const T* data = chunk.data();
for (int index = 0; index < segcore::DefaultElementPerChunk; ++index) {
for (int index = 0; index < chunk_size; ++index) {
result[index] = element_func(data[index]);
}
Assert(result.size() == segcore::DefaultElementPerChunk);
Assert(result.size() == chunk_size);
}
return results;
}
@ -290,13 +291,13 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
auto N = records.ack_responder_.GetAck();
// small batch
auto chunk_size = vec.get_chunk_size();
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto& chunk = vec.get_chunk(chunk_id);
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk
: segcore::DefaultElementPerChunk;
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * chunk_size : chunk_size;
boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk);
boost::dynamic_bitset<> bitset(chunk_size);
for (int i = 0; i < size; ++i) {
auto value = chunk[i];
bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value);

View File

@ -20,6 +20,7 @@
#include <vector>
#include <utility>
#include "utils/EasyAssert.h"
#include "utils/tools.h"
#include <boost/container/vector.hpp>
namespace milvus::segcore {
@ -53,7 +54,6 @@ namespace milvus::segcore {
template <typename Type>
using FixedVector = boost::container::vector<Type>;
constexpr int64_t DefaultElementPerChunk = 32 * 1024;
template <typename Type>
class ThreadSafeVector {
@ -98,7 +98,8 @@ class ThreadSafeVector {
class VectorBase {
public:
VectorBase() = default;
explicit VectorBase(int64_t chunk_size) : chunk_size_(chunk_size) {
}
virtual ~VectorBase() = default;
virtual void
@ -106,9 +107,17 @@ class VectorBase {
virtual void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
int64_t
get_chunk_size() const {
return chunk_size_;
}
protected:
const int64_t chunk_size_;
};
template <typename Type, bool is_scalar = false, ssize_t ElementsPerChunk = DefaultElementPerChunk>
template <typename Type, bool is_scalar = false>
class ConcurrentVectorImpl : public VectorBase {
public:
// constants
@ -122,14 +131,14 @@ class ConcurrentVectorImpl : public VectorBase {
operator=(const ConcurrentVectorImpl&) = delete;
public:
explicit ConcurrentVectorImpl(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) {
explicit ConcurrentVectorImpl(ssize_t dim, int64_t chunk_size) : VectorBase(chunk_size), Dim(is_scalar ? 1 : dim) {
Assert(is_scalar ? dim == 1 : dim != 1);
}
void
grow_to_at_least(int64_t element_count) override {
auto chunk_count = (element_count + ElementsPerChunk - 1) / ElementsPerChunk;
chunks_.emplace_to_at_least(chunk_count, SizePerChunk);
auto chunk_count = upper_div(element_count, chunk_size_);
chunks_.emplace_to_at_least(chunk_count, Dim * chunk_size_);
}
void
@ -143,28 +152,28 @@ class ConcurrentVectorImpl : public VectorBase {
return;
}
this->grow_to_at_least(element_offset + element_count);
auto chunk_id = element_offset / ElementsPerChunk;
auto chunk_offset = element_offset % ElementsPerChunk;
auto chunk_id = element_offset / chunk_size_;
auto chunk_offset = element_offset % chunk_size_;
ssize_t source_offset = 0;
// first partition:
if (chunk_offset + element_count <= ElementsPerChunk) {
if (chunk_offset + element_count <= chunk_size_) {
// only first
fill_chunk(chunk_id, chunk_offset, element_count, source, source_offset);
return;
}
auto first_size = ElementsPerChunk - chunk_offset;
auto first_size = chunk_size_ - chunk_offset;
fill_chunk(chunk_id, chunk_offset, first_size, source, source_offset);
source_offset += ElementsPerChunk - chunk_offset;
source_offset += chunk_size_ - chunk_offset;
element_count -= first_size;
++chunk_id;
// the middle
while (element_count >= ElementsPerChunk) {
fill_chunk(chunk_id, 0, ElementsPerChunk, source, source_offset);
source_offset += ElementsPerChunk;
element_count -= ElementsPerChunk;
while (element_count >= chunk_size_) {
fill_chunk(chunk_id, 0, chunk_size_, source, source_offset);
source_offset += chunk_size_;
element_count -= chunk_size_;
++chunk_id;
}
@ -182,16 +191,16 @@ class ConcurrentVectorImpl : public VectorBase {
// just for fun, don't use it directly
const Type*
get_element(ssize_t element_index) const {
auto chunk_id = element_index / ElementsPerChunk;
auto chunk_offset = element_index % ElementsPerChunk;
auto chunk_id = element_index / chunk_size_;
auto chunk_offset = element_index % chunk_size_;
return get_chunk(chunk_id).data() + chunk_offset * Dim;
}
const Type&
operator[](ssize_t element_index) const {
Assert(Dim == 1);
auto chunk_id = element_index / ElementsPerChunk;
auto chunk_offset = element_index % ElementsPerChunk;
auto chunk_id = element_index / chunk_size_;
auto chunk_offset = element_index % chunk_size_;
return get_chunk(chunk_id)[chunk_offset];
}
@ -215,7 +224,6 @@ class ConcurrentVectorImpl : public VectorBase {
}
const ssize_t Dim;
const ssize_t SizePerChunk;
private:
ThreadSafeVector<Chunk> chunks_;
@ -223,7 +231,10 @@ class ConcurrentVectorImpl : public VectorBase {
template <typename Type>
class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
using ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl;
public:
explicit ConcurrentVector(int64_t chunk_size)
: ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl(1, chunk_size) {
}
};
class VectorTrait {};
@ -237,13 +248,17 @@ class BinaryVector : public VectorTrait {
template <>
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
using ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl;
public:
ConcurrentVector(int64_t dim, int64_t chunk_size)
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(dim, chunk_size) {
}
};
template <>
class ConcurrentVector<BinaryVector> : public ConcurrentVectorImpl<uint8_t, false> {
public:
explicit ConcurrentVector(int64_t dim) : binary_dim_(dim), ConcurrentVectorImpl(dim / 8) {
explicit ConcurrentVector(int64_t dim, int64_t chunk_size)
: binary_dim_(dim), ConcurrentVectorImpl(dim / 8, chunk_size) {
Assert(dim % 8 == 0);
}

View File

@ -29,8 +29,9 @@ struct DeletedRecord {
std::shared_ptr<TmpBitmap>
clone(int64_t capacity);
};
DeletedRecord() : lru_(std::make_shared<TmpBitmap>()) {
static constexpr int64_t deprecated_chunk_size = 32 * 1024;
DeletedRecord()
: lru_(std::make_shared<TmpBitmap>()), timestamps_(deprecated_chunk_size), uids_(deprecated_chunk_size) {
lru_->bitmap_ptr = std::make_shared<faiss::ConcurrentBitset>(0);
}

View File

@ -24,8 +24,8 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
Assert(source);
auto chunk_size = source->num_chunk();
assert(ack_end <= chunk_size);
auto num_chunk = source->num_chunk();
assert(ack_end <= num_chunk);
auto conf = get_build_conf();
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
@ -33,7 +33,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
// build index for chunk
// TODO
auto indexing = std::make_unique<knowhere::IVF>();
auto dataset = knowhere::GenDataset(DefaultElementPerChunk, dim, chunk.data());
auto dataset = knowhere::GenDataset(source->get_chunk_size(), dim, chunk.data());
indexing->Train(dataset, conf);
indexing->AddWithoutIds(dataset, conf);
data_[chunk_id] = std::move(indexing);
@ -87,25 +87,24 @@ void
ScalarIndexingEntry<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
Assert(source);
auto chunk_size = source->num_chunk();
assert(ack_end <= chunk_size);
auto num_chunk = source->num_chunk();
assert(ack_end <= num_chunk);
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
const auto& chunk = source->get_chunk(chunk_id);
// build index for chunk
// TODO
Assert(chunk.size() == DefaultElementPerChunk);
auto indexing = std::make_unique<knowhere::scalar::StructuredIndexSort<T>>();
indexing->Build(DefaultElementPerChunk, chunk.data());
indexing->Build(vec_base->get_chunk_size(), chunk.data());
data_[chunk_id] = std::move(indexing);
}
}
std::unique_ptr<IndexingEntry>
CreateIndex(const FieldMeta& field_meta) {
CreateIndex(const FieldMeta& field_meta, int64_t chunk_size) {
if (field_meta.is_vector()) {
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
return std::make_unique<VecIndexingEntry>(field_meta);
return std::make_unique<VecIndexingEntry>(field_meta, chunk_size);
} else {
// TODO
PanicInfo("unsupported");
@ -113,17 +112,17 @@ CreateIndex(const FieldMeta& field_meta) {
}
switch (field_meta.get_data_type()) {
case DataType::INT8:
return std::make_unique<ScalarIndexingEntry<int8_t>>(field_meta);
return std::make_unique<ScalarIndexingEntry<int8_t>>(field_meta, chunk_size);
case DataType::INT16:
return std::make_unique<ScalarIndexingEntry<int16_t>>(field_meta);
return std::make_unique<ScalarIndexingEntry<int16_t>>(field_meta, chunk_size);
case DataType::INT32:
return std::make_unique<ScalarIndexingEntry<int32_t>>(field_meta);
return std::make_unique<ScalarIndexingEntry<int32_t>>(field_meta, chunk_size);
case DataType::INT64:
return std::make_unique<ScalarIndexingEntry<int64_t>>(field_meta);
return std::make_unique<ScalarIndexingEntry<int64_t>>(field_meta, chunk_size);
case DataType::FLOAT:
return std::make_unique<ScalarIndexingEntry<float>>(field_meta);
return std::make_unique<ScalarIndexingEntry<float>>(field_meta, chunk_size);
case DataType::DOUBLE:
return std::make_unique<ScalarIndexingEntry<double>>(field_meta);
return std::make_unique<ScalarIndexingEntry<double>>(field_meta, chunk_size);
default:
PanicInfo("unsupported");
}

View File

@ -26,7 +26,8 @@ namespace milvus::segcore {
// All concurrent
class IndexingEntry {
public:
explicit IndexingEntry(const FieldMeta& field_meta) : field_meta_(field_meta) {
explicit IndexingEntry(const FieldMeta& field_meta, int64_t chunk_size)
: field_meta_(field_meta), chunk_size_(chunk_size) {
}
IndexingEntry(const IndexingEntry&) = delete;
IndexingEntry&
@ -41,9 +42,15 @@ class IndexingEntry {
return field_meta_;
}
int64_t
get_chunk_size() const {
return chunk_size_;
}
protected:
// additional info
const FieldMeta& field_meta_;
const int64_t chunk_size_;
};
template <typename T>
class ScalarIndexingEntry : public IndexingEntry {
@ -88,11 +95,11 @@ class VecIndexingEntry : public IndexingEntry {
};
std::unique_ptr<IndexingEntry>
CreateIndex(const FieldMeta& field_meta);
CreateIndex(const FieldMeta& field_meta, int64_t chunk_size);
class IndexingRecord {
public:
explicit IndexingRecord(const Schema& schema) : schema_(schema) {
explicit IndexingRecord(const Schema& schema, int64_t chunk_size) : schema_(schema), chunk_size_(chunk_size) {
Initialize();
}
@ -101,7 +108,7 @@ class IndexingRecord {
int offset = 0;
for (auto& field : schema_) {
if (field.get_data_type() != DataType::VECTOR_BINARY) {
entries_.try_emplace(offset, CreateIndex(field));
entries_.try_emplace(offset, CreateIndex(field, chunk_size_));
}
++offset;
}
@ -149,6 +156,7 @@ class IndexingRecord {
// std::atomic<int64_t> finished_ack_ = 0;
AckResponder finished_ack_;
std::mutex mutex_;
int64_t chunk_size_;
private:
// field_offset => indexing

View File

@ -13,14 +13,14 @@
namespace milvus::segcore {
InsertRecord::InsertRecord(const Schema& schema) : uids_(1), timestamps_(1) {
InsertRecord::InsertRecord(const Schema& schema, int64_t chunk_size) : uids_(1), timestamps_(1) {
for (auto& field : schema) {
if (field.is_vector()) {
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<FloatVector>>(field.get_dim()));
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<FloatVector>>(field.get_dim(), chunk_size));
continue;
} else if (field.get_data_type() == DataType::VECTOR_BINARY) {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<BinaryVector>>(field.get_dim()));
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<BinaryVector>>(field.get_dim(), chunk_size));
continue;
} else {
PanicInfo("unsupported");
@ -28,30 +28,30 @@ InsertRecord::InsertRecord(const Schema& schema) : uids_(1), timestamps_(1) {
}
switch (field.get_data_type()) {
case DataType::INT8: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int8_t>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int8_t>>(chunk_size));
break;
}
case DataType::INT16: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int16_t>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int16_t>>(chunk_size));
break;
}
case DataType::INT32: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int32_t>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int32_t>>(chunk_size));
break;
}
case DataType::INT64: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int64_t>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<int64_t>>(chunk_size));
break;
}
case DataType::FLOAT: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<float>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<float>>(chunk_size));
break;
}
case DataType::DOUBLE: {
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<double>>());
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<double>>(chunk_size));
break;
}
default: {

View File

@ -25,7 +25,7 @@ struct InsertRecord {
ConcurrentVector<idx_t> uids_;
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
explicit InsertRecord(const Schema& schema);
explicit InsertRecord(const Schema& schema, int64_t chunk_size);
template <typename Type>
auto
get_entity(int offset) const {

View File

@ -10,7 +10,6 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "segcore/SegmentBase.h"
#include "segcore/SegmentNaive.h"
#include "segcore/SegmentSmallIndex.h"
namespace milvus::segcore {
@ -46,8 +45,8 @@ TestABI() {
}
std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema) {
auto segment = std::make_unique<SegmentSmallIndex>(schema);
CreateSegment(SchemaPtr schema, int64_t chunk_size) {
auto segment = std::make_unique<SegmentSmallIndex>(schema, chunk_size);
return segment;
}
} // namespace milvus::segcore

View File

@ -113,7 +113,7 @@ class SegmentBase {
using SegmentBasePtr = std::unique_ptr<SegmentBase>;
SegmentBasePtr
CreateSegment(SchemaPtr schema);
CreateSegment(SchemaPtr schema, int64_t chunk_size = 32 * 1024);
} // namespace segcore
} // namespace milvus

View File

@ -292,61 +292,7 @@ SegmentNaive::QueryImpl(query::QueryDeprecatedPtr query_info, Timestamp timestam
Status
SegmentNaive::QueryBruteForceImpl(query::QueryDeprecatedPtr query_info, Timestamp timestamp, QueryResult& results) {
auto ins_barrier = get_barrier(record_, timestamp);
auto del_barrier = get_barrier(deleted_record_, timestamp);
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);
auto& field = schema_->operator[](query_info->field_name);
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto bitmap = bitmap_holder->bitmap_ptr;
auto topK = query_info->topK;
auto num_queries = query_info->num_queries;
auto total_count = topK * num_queries;
// TODO: optimize
auto the_offset_opt = schema_->get_offset(query_info->field_name);
Assert(the_offset_opt.has_value());
Assert(the_offset_opt.value() < record_.entity_vec_.size());
auto vec_ptr =
std::static_pointer_cast<ConcurrentVector<FloatVector>>(record_.entity_vec_.at(the_offset_opt.value()));
std::vector<int64_t> final_uids(total_count);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
auto max_chunk = (ins_barrier + DefaultElementPerChunk - 1) / DefaultElementPerChunk;
for (int chunk_id = 0; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto src_data = vec_ptr->get_chunk(chunk_id).data();
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto offset = chunk_id * DefaultElementPerChunk;
faiss::BitsetView view(bitmap->data() + offset / 8, DefaultElementPerChunk);
faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf, view);
if (chunk_id == 0) {
final_uids = buf_uids;
final_dis = buf_dis;
} else {
merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
}
}
for (auto& id : final_uids) {
id = record_.uids_[id];
}
results.result_ids_ = std::move(final_uids);
results.result_distances_ = std::move(final_dis);
results.topK_ = topK;
results.num_queries_ = num_queries;
// throw std::runtime_error("unimplemented");
return Status::OK();
PanicInfo("deprecated");
}
Status
@ -460,32 +406,7 @@ SegmentNaive::Close() {
template <typename Type>
knowhere::IndexPtr
SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto offset_opt = schema_->get_offset(entry.field_name);
Assert(offset_opt.has_value());
auto offset = offset_opt.value();
auto field = (*schema_)[offset];
auto dim = field.get_dim();
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.num_chunk();
auto& uids = record_.uids_;
auto entities = record_.get_entity<FloatVector>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {
indexing->Train(ds, entry.config);
}
for (auto& ds : datasets) {
indexing->AddWithoutIds(ds, entry.config);
}
return indexing;
PanicInfo("deprecated");
}
Status
@ -544,20 +465,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
int64_t
SegmentNaive::GetMemoryUsageInBytes() {
int64_t total_bytes = 0;
if (index_ready_) {
auto& index_entries = index_meta_->get_entries();
for (auto [index_name, entry] : index_entries) {
Assert(schema_->operator[](entry.field_name).is_vector());
auto vec_ptr = std::static_pointer_cast<knowhere::VecIndex>(indexings_[index_name]);
total_bytes += vec_ptr->IndexSize();
}
}
int64_t ins_n = (record_.reserved + DefaultElementPerChunk - 1) & ~(DefaultElementPerChunk - 1);
total_bytes += ins_n * (schema_->get_total_sizeof() + 16 + 1);
int64_t del_n = (deleted_record_.reserved + DefaultElementPerChunk - 1) & ~(DefaultElementPerChunk - 1);
total_bytes += del_n * (16 * 2);
return total_bytes;
PanicInfo("Deprecated");
}
} // namespace milvus::segcore

View File

@ -124,7 +124,8 @@ class SegmentNaive : public SegmentBase {
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema);
explicit SegmentNaive(const SchemaPtr& schema) : schema_(schema), record_(*schema) {
static constexpr int64_t deprecated_fixed_chunk_size = 32 * 1024;
explicit SegmentNaive(const SchemaPtr& schema) : schema_(schema), record_(*schema, deprecated_fixed_chunk_size) {
}
private:

View File

@ -16,7 +16,6 @@
#include <thread>
#include <queue>
#include "segcore/SegmentNaive.h"
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include <faiss/utils/distances.h>
@ -179,7 +178,7 @@ SegmentSmallIndex::Insert(int64_t reserved_begin,
}
record_.ack_responder_.AddSegment(reserved_begin, reserved_begin + size);
indexing_record_.UpdateResourceAck(record_.ack_responder_.GetAck() / DefaultElementPerChunk, record_);
indexing_record_.UpdateResourceAck(record_.ack_responder_.GetAck() / chunk_size_, record_);
return Status::OK();
}
@ -243,8 +242,7 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * chunk_size_ : chunk_size_;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {
@ -326,9 +324,9 @@ SegmentSmallIndex::GetMemoryUsageInBytes() {
}
}
#endif
int64_t ins_n = upper_align(record_.reserved, DefaultElementPerChunk);
int64_t ins_n = upper_align(record_.reserved, chunk_size_);
total_bytes += ins_n * (schema_->get_total_sizeof() + 16 + 1);
int64_t del_n = upper_align(deleted_record_.reserved, DefaultElementPerChunk);
int64_t del_n = upper_align(deleted_record_.reserved, chunk_size_);
total_bytes += del_n * (16 * 2);
return total_bytes;
}

View File

@ -131,10 +131,13 @@ class SegmentSmallIndex : public SegmentBase {
public:
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema);
CreateSegment(SchemaPtr schema, int64_t chunk_size);
explicit SegmentSmallIndex(SchemaPtr schema)
: schema_(std::move(schema)), record_(*schema_), indexing_record_(*schema_) {
explicit SegmentSmallIndex(SchemaPtr schema, int64_t chunk_size)
: chunk_size_(chunk_size),
schema_(std::move(schema)),
record_(*schema_, chunk_size),
indexing_record_(*schema_, chunk_size) {
}
public:
@ -149,6 +152,7 @@ class SegmentSmallIndex : public SegmentBase {
FillTargetEntry(const query::Plan* Plan, QueryResult& results) override;
private:
int64_t chunk_size_;
SchemaPtr schema_;
std::atomic<SegmentState> state_ = SegmentState::Open;
IndexMetaPtr index_meta_;
@ -157,8 +161,6 @@ class SegmentSmallIndex : public SegmentBase {
DeletedRecord deleted_record_;
IndexingRecord indexing_record_;
// std::atomic<bool> index_ready_ = false;
// std::unordered_map<std::string, knowhere::IndexPtr> indexings_; // index_name => indexing
tbb::concurrent_unordered_multimap<idx_t, int64_t> uid2offset_;
};

View File

@ -38,7 +38,7 @@ TEST(ConcurrentVector, TestABI) {
TEST(ConcurrentVector, TestSingle) {
auto dim = 8;
ConcurrentVectorImpl<int, false, 32> c_vec(dim);
ConcurrentVectorImpl<int, false> c_vec(dim, 32);
std::default_random_engine e(42);
int data = 0;
auto total_count = 0;
@ -66,7 +66,7 @@ TEST(ConcurrentVector, TestMultithreads) {
constexpr int threads = 16;
std::vector<int64_t> total_counts(threads);
ConcurrentVectorImpl<int64_t, false, 32> c_vec(dim);
ConcurrentVectorImpl<int64_t, false> c_vec(dim, 32);
std::atomic<int64_t> ack_counter = 0;
// std::mutex mutex;

View File

@ -313,11 +313,11 @@ TEST(Expr, TestRange) {
dsl_string.replace(loc, 4, clause);
auto plan = CreatePlan(*schema, dsl_string);
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
EXPECT_EQ(final.size(), upper_div(N * num_iters, TestChunkSize));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
auto vec_id = i / TestChunkSize;
auto offset = i % TestChunkSize;
auto ans = final[vec_id][offset];
auto val = age_col[i];
@ -397,11 +397,11 @@ TEST(Expr, TestTerm) {
dsl_string.replace(loc, 4, clause);
auto plan = CreatePlan(*schema, dsl_string);
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
EXPECT_EQ(final.size(), upper_div(N * num_iters, TestChunkSize));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
auto vec_id = i / TestChunkSize;
auto offset = i % TestChunkSize;
auto ans = final[vec_id][offset];
auto val = age_col[i];
@ -499,11 +499,11 @@ TEST(Expr, TestSimpleDsl) {
// std::cout << dsl.dump(2);
auto plan = CreatePlan(*schema, dsl.dump());
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
EXPECT_EQ(final.size(), upper_div(N * num_iters, TestChunkSize));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
auto vec_id = i / TestChunkSize;
auto offset = i % TestChunkSize;
bool ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);

View File

@ -93,11 +93,11 @@ TEST(Indexing, SmartBruteForce) {
vector<int64_t> final_uids(total_count, -1);
vector<float> final_dis(total_count, std::numeric_limits<float>::max());
for (int beg = 0; beg < N; beg += DefaultElementPerChunk) {
for (int beg = 0; beg < N; beg += TestChunkSize) {
vector<int64_t> buf_uids(total_count, -1);
vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
faiss::float_maxheap_array_t buf = {queries, TOPK, buf_uids.data(), buf_dis.data()};
auto end = beg + DefaultElementPerChunk;
auto end = beg + TestChunkSize;
if (end > N) {
end = N;
}
@ -152,8 +152,8 @@ TEST(Indexing, DISABLED_Naive) {
std::vector<knowhere::DatasetPtr> datasets;
std::vector<std::vector<float>> ftrashs;
auto raw = raw_data.data();
for (int beg = 0; beg < N; beg += DefaultElementPerChunk) {
auto end = beg + DefaultElementPerChunk;
for (int beg = 0; beg < N; beg += TestChunkSize) {
auto end = beg + TestChunkSize;
if (end > N) {
end = N;
}

View File

@ -202,7 +202,7 @@ TEST(Query, ExecWithPredicate) {
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = std::make_unique<SegmentSmallIndex>(schema);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
@ -292,7 +292,7 @@ TEST(Query, ExecTerm) {
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = std::make_unique<SegmentSmallIndex>(schema);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
@ -338,7 +338,7 @@ TEST(Query, ExecWithoutPredicate) {
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = std::make_unique<SegmentSmallIndex>(schema);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
@ -512,7 +512,7 @@ TEST(Query, ExecWithPredicateBinary) {
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = std::make_unique<SegmentSmallIndex>(schema);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
auto vec_ptr = dataset.get_col<uint8_t>(0);

View File

@ -0,0 +1,13 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once

#include <cstdint>

// Number of elements per chunk assumed by the unit tests.
// Must stay equal to the default `chunk_size` argument of
// segcore::CreateSegment (32 * 1024): tests create segments with that default
// and then use this constant to compute expected chunk ids and offsets.
inline constexpr int64_t TestChunkSize = 32 * 1024;

View File

@ -15,6 +15,7 @@
#include <memory>
#include <cstring>
#include "segcore/SegmentBase.h"
#include "Constants.h"
namespace milvus::segcore {
struct GeneratedData {