Add MinIO kv implementation

Signed-off-by: godchen <qingxiang.chen@zilliz.com>
godchen 2020-12-03 19:00:11 +08:00 committed by yefu.chen
parent cec903da19
commit 7ab5b5d80d
92 changed files with 7505 additions and 330 deletions

.env
View File

@ -1,8 +1,9 @@
REPO=milvusdb/milvus-distributed-dev
ARCH=amd64
UBUNTU=18.04
DATE_VERSION=20201120-092740
DATE_VERSION=20201202-085131
LATEST_DATE_VERSION=latest
MINIO_ADDRESS=minio:9000
PULSAR_ADDRESS=pulsar://pulsar:6650
ETCD_ADDRESS=etcd:2379
MASTER_ADDRESS=localhost:53100

View File

@ -51,7 +51,7 @@ jobs:
- name: Start Service
shell: bash
run: |
docker-compose up -d pulsar etcd
docker-compose up -d pulsar etcd minio
- name: Build and UnitTest
env:
CHECK_BUILDER: "1"

View File

@ -6,12 +6,12 @@ on:
push:
# file paths to consider in the event. Optional; defaults to all.
paths:
- 'build/docker/**'
- 'build/docker/env/**'
- '.github/workflows/publish-builder.yaml'
pull_request:
# file paths to consider in the event. Optional; defaults to all.
paths:
- 'build/docker/**'
- 'build/docker/env/**'
- '.github/workflows/publish-builder.yaml'
jobs:

View File

@ -12,12 +12,15 @@ dir ('build/docker/deploy') {
try {
withCredentials([usernamePassword(credentialsId: "${env.DOCKER_CREDENTIALS_ID}", usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) {
sh 'docker login -u ${DOCKER_USERNAME} -p ${DOCKER_PASSWORD} ${DOKCER_REGISTRY_URL}'
sh 'docker pull ${SOURCE_REPO}/master:${SOURCE_TAG} || true'
sh 'docker-compose build --force-rm master'
sh 'docker-compose push master'
sh 'docker pull ${SOURCE_REPO}/proxy:${SOURCE_TAG} || true'
sh 'docker-compose build --force-rm proxy'
sh 'docker-compose push proxy'
sh 'docker pull ${SOURCE_REPO}/querynode:${SOURCE_TAG} || true'
sh 'docker-compose build --force-rm querynode'
sh 'docker-compose push querynode'

View File

@ -0,0 +1,37 @@
try {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d etcd'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
dir ('build/docker/deploy') {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
}
dir ('build/docker/test') {
sh 'docker pull ${SOURCE_REPO}/pytest:${SOURCE_TAG} || true'
sh 'docker-compose build --force-rm regression'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run --rm regression'
try {
withCredentials([usernamePassword(credentialsId: "${env.DOCKER_CREDENTIALS_ID}", usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) {
sh 'docker login -u ${DOCKER_USERNAME} -p ${DOCKER_PASSWORD} ${DOKCER_REGISTRY_URL}'
sh 'docker-compose push regression'
}
} catch (exc) {
throw exc
} finally {
sh 'docker logout ${DOKCER_REGISTRY_URL}'
}
}
} catch(exc) {
throw exc
} finally {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v pulsar'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v etcd'
dir ('build/docker/deploy') {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true'
}
dir ('build/docker/test') {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run --rm regression /bin/bash -c "rm -rf __pycache__ && rm -rf .pytest_cache"'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true'
}
}

View File

@ -35,21 +35,21 @@ fmt:
@echo "Running $@ check"
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh cmd/
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh internal/
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh test/
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/go/
#TODO: Check code specifications by golangci-lint
lint:
@echo "Running $@ check"
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./test/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
ruleguard:
@echo "Running $@ check"
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./internal/...
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./cmd/...
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./test/...
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./tests/go/...
verifiers: cppcheck fmt lint ruleguard

View File

@ -18,6 +18,11 @@ pipeline {
PACKAGE_ARTFACTORY_URL = "${JFROG_ARTFACTORY_URL}/${PROJECT_NAME}/package/${PACKAGE_NAME}"
DOCKER_CREDENTIALS_ID = "ba070c98-c8cc-4f7c-b657-897715f359fc"
DOKCER_REGISTRY_URL = "registry.zilliz.com"
SOURCE_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
TARGET_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
SOURCE_TAG = "${CHANGE_TARGET ? CHANGE_TARGET : SEMVER}-${LOWER_BUILD_TYPE}"
TARGET_TAG = "${SEMVER}-${LOWER_BUILD_TYPE}"
DOCKER_BUILDKIT = 1
}
stages {
stage ('Build and UnitTest') {
@ -51,18 +56,28 @@ pipeline {
yamlFile "build/ci/jenkins/pod/docker-pod.yaml"
}
}
environment{
SOURCE_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
TARGET_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
SOURCE_TAG = "${CHANGE_TARGET ? CHANGE_TARGET : SEMVER}-${LOWER_BUILD_TYPE}"
TARGET_TAG = "${SEMVER}-${LOWER_BUILD_TYPE}"
DOCKER_BUILDKIT = 1
}
steps {
container('publish-images') {
MPLModule('Publish')
}
}
}
stage ('Dev Test') {
agent {
label "performance"
}
environment {
DOCKER_COMPOSE_PROJECT_NAME = "${PROJECT_NAME}-${SEMVER}-${env.BUILD_NUMBER}".replaceAll("\\.", "-").replaceAll("_", "-")
}
steps {
MPLModule('Python Regression')
}
post {
cleanup {
deleteDir() /* clean up our workspace */
}
}
}
}
}

View File

@ -13,8 +13,6 @@ services:
ETCD_ADDRESS: ${ETCD_ADDRESS}
networks:
- milvus
ports:
- "53100:53100"
proxy:
image: ${TARGET_REPO}/proxy:${TARGET_TAG}
@ -26,8 +24,6 @@ services:
environment:
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
MASTER_ADDRESS: ${MASTER_ADDRESS}
ports:
- "19530:19530"
networks:
- milvus

View File

@ -0,0 +1,18 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
FROM python:3.6.8-jessie
COPY ./tests/python/requirements.txt /requirements.txt
RUN python3 -m pip install -r /requirements.txt
CMD ["tail", "-f", "/dev/null"]

View File

@ -0,0 +1,20 @@
version: '3.5'
services:
regression:
image: ${TARGET_REPO}/pytest:${TARGET_TAG}
build:
context: ../../../
dockerfile: build/docker/test/Dockerfile
cache_from:
- ${SOURCE_REPO}/pytest:${SOURCE_TAG}
volumes:
- ../../..:/milvus-distributed:delegated
working_dir: "/milvus-distributed/tests/python"
command: >
/bin/bash -c "pytest --ip proxy"
networks:
- milvus
networks:
milvus:

View File

@ -42,6 +42,13 @@ etcd:
rootpath: by-dev
segthreshold: 10000
minio:
address: localhost
port: 9000
accessKeyID: minioadmin
secretAccessKey: minioadmin
useSSL: false
timesync:
interval: 400

View File

@ -20,6 +20,22 @@ services:
networks:
- milvus
minio:
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
ports:
- "9000:9000"
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
command: minio server /minio_data
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 20s
retries: 3
networks:
- milvus
networks:
milvus:

View File

@ -21,6 +21,7 @@ services:
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
ETCD_ADDRESS: ${ETCD_ADDRESS}
MASTER_ADDRESS: ${MASTER_ADDRESS}
MINIO_ADDRESS: ${MINIO_ADDRESS}
volumes: &ubuntu-volumes
- .:/go/src/github.com/zilliztech/milvus-distributed:delegated
- ${DOCKER_VOLUME_DIRECTORY:-.docker}/${ARCH}-ubuntu${UBUNTU}-cache:/ccache:delegated
@ -45,6 +46,7 @@ services:
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
ETCD_ADDRESS: ${ETCD_ADDRESS}
MASTER_ADDRESS: ${MASTER_ADDRESS}
MINIO_ADDRESS: ${MINIO_ADDRESS}
volumes:
- .:/go/src/github.com/zilliztech/milvus-distributed:delegated
- ${DOCKER_VOLUME_DIRECTORY:-.docker}/${ARCH}-ubuntu${UBUNTU}-gdbserver-home:/home/debugger:delegated
@ -59,16 +61,28 @@ services:
etcd:
image: quay.io/coreos/etcd:v3.4.13
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379
ports:
- "2379:2379"
networks:
- milvus
pulsar:
image: apachepulsar/pulsar:2.6.1
command: bin/pulsar standalone
networks:
- milvus
minio:
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
ports:
- "6650:6650"
- "9000:9000"
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
command: minio server /minio_data
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 20s
retries: 3
networks:
- milvus

View File

@ -184,7 +184,7 @@ Note that *tenantId*, *proxyId*, *collectionId*, *segmentId* are unique strings
```go
type metaTable struct {
kv kv.Base // client of a reliable kv service, i.e. etcd client
kv kv.TxnBase // client of a reliable kv service, i.e. etcd client
tenantId2Meta map[UniqueId]TenantMeta // tenant id to tenant meta
proxyId2Meta map[UniqueId]ProxyMeta // proxy id to proxy meta
collId2Meta map[UniqueId]CollectionMeta // collection id to collection meta
@ -216,7 +216,7 @@ func (meta *metaTable) GetSegmentById(segId UniqueId)(*SegmentMeta, error)
func (meta *metaTable) DeleteSegment(segId UniqueId) error
func (meta *metaTable) CloseSegment(segId UniqueId, closeTs Timestamp, num_rows int64) error
func NewMetaTable(kv kv.Base) (*metaTable,error)
func NewMetaTable(kv kv.TxnBase) (*metaTable,error)
```
*metaTable* maintains meta both in memory and in *etcdKV*, and keeps the two consistent. All its member functions may be called concurrently.
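
To make that rule concrete, here is a minimal, hypothetical sketch of one such member function. It assumes the struct above plus a `ddLock sync.RWMutex` guard; the method name, key layout, and proto text encoding are illustrative assumptions, not the actual implementation.

```go
// Write-through sketch: persist to the reliable kv service first, then update
// the in-memory map under the write lock, so a failed kv write never leaves
// the cache ahead of the store.
func (meta *metaTable) saveCollectionMeta(collID UniqueId, coll CollectionMeta) error {
	meta.ddLock.Lock()
	defer meta.ddLock.Unlock()

	key := "/collection/" + strconv.FormatInt(int64(collID), 10) // assumed key layout
	value := proto.MarshalTextString(&coll)                      // assumes a protobuf-encoded meta value

	if err := meta.kv.Save(key, value); err != nil {
		return err // memory untouched, both sides still agree
	}
	meta.collId2Meta[collID] = coll
	return nil
}
```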

View File

@ -54,7 +54,7 @@ func (ia *IDAllocator) syncID() {
cancel()
if err != nil {
log.Panic("syncID Failed!!!!!")
log.Println("syncID Failed!!!!!")
return
}
ia.idStart = resp.GetID()

View File

@ -210,7 +210,7 @@ func (sa *SegIDAssigner) syncSegments() {
}
if err != nil {
log.Panic("syncID Failed!!!!!")
log.Println("syncSemgnet Failed!!!!!")
return
}
}

View File

@ -63,7 +63,7 @@ func (ta *TimestampAllocator) syncTs() {
cancel()
if err != nil {
log.Panic("syncID Failed!!!!!")
log.Println("syncTimestamp Failed!!!!!")
return
}
ta.lastTsBegin = resp.GetTimestamp()

View File

@ -93,8 +93,8 @@ endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
if (KNOWHERE_BUILD_TESTS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
add_subdirectory(unittest)
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
#add_subdirectory(unittest)
endif ()
config_summary()

View File

@ -143,6 +143,7 @@ endif ()
target_link_libraries(
knowhere
milvus_utils
${depend_libs}
)

View File

@ -12,4 +12,4 @@ set(MILVUS_QUERY_SRCS
BruteForceSearch.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query milvus_proto)
target_link_libraries(milvus_query milvus_proto milvus_utils)

View File

@ -24,18 +24,22 @@ namespace milvus::query {
static std::unique_ptr<VectorPlanNode>
ParseVecNode(Plan* plan, const Json& out_body) {
Assert(out_body.is_object());
// TODO add binary info
auto vec_node = std::make_unique<FloatVectorANNS>();
Assert(out_body.size() == 1);
auto iter = out_body.begin();
std::string field_name = iter.key();
auto& vec_info = iter.value();
Assert(vec_info.is_object());
auto topK = vec_info["topk"];
AssertInfo(topK > 0, "topK must greater than 0");
AssertInfo(topK < 16384, "topK is too large");
vec_node->query_info_.topK_ = topK;
vec_node->query_info_.metric_type_ = vec_info["metric_type"];
vec_node->query_info_.search_params_ = vec_info["params"];
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
vec_node->query_info_.search_params_ = vec_info.at("params");
vec_node->query_info_.field_id_ = field_name;
vec_node->placeholder_tag_ = vec_info["query"];
vec_node->placeholder_tag_ = vec_info.at("query");
auto tag = vec_node->placeholder_tag_;
AssertInfo(!plan->tag2field_.count(tag), "duplicated placeholder tag");
plan->tag2field_.emplace(tag, field_name);
@ -56,6 +60,8 @@ to_lower(const std::string& raw) {
return data;
}
template <class...>
constexpr std::false_type always_false{};
template <typename T>
std::unique_ptr<Expr>
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
@ -63,11 +69,19 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
auto data_type = schema[field_name].get_data_type();
expr->data_type_ = data_type;
expr->field_id_ = field_name;
Assert(body.is_object());
for (auto& item : body.items()) {
auto op_name = to_lower(item.key());
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
auto op = RangeExpr::mapping_.at(op_name);
if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
} else {
static_assert(always_false<T>, "unsupported type");
}
T value = item.value();
expr->conditions_.emplace_back(op, value);
}
@ -83,8 +97,10 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL:
return ParseRangeNodeImpl<bool>(schema, field_name, body);
case DataType::BOOL: {
PanicInfo("bool is not supported in Range node");
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
}
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
case DataType::INT16:
@ -109,16 +125,17 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
nlohmann::json vec_pack;
std::optional<std::unique_ptr<Expr>> predicate;
auto& bool_dsl = dsl["bool"];
auto& bool_dsl = dsl.at("bool");
if (bool_dsl.contains("must")) {
auto& packs = bool_dsl["must"];
auto& packs = bool_dsl.at("must");
Assert(packs.is_array());
for (auto& pack : packs) {
if (pack.contains("vector")) {
auto& out_body = pack["vector"];
auto& out_body = pack.at("vector");
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack["range"];
auto& out_body = pack.at("range");
predicate = ParseRangeNode(schema, out_body);
} else {
PanicInfo("unsupported node");
@ -126,7 +143,7 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
}
AssertInfo(plan->plan_node_, "vector node not found");
} else if (bool_dsl.contains("vector")) {
auto& out_body = bool_dsl["vector"];
auto& out_body = bool_dsl.at("vector");
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
Assert(plan->plan_node_);
} else {

View File

@ -17,7 +17,7 @@ add_library(milvus_segcore SHARED
)
target_link_libraries(milvus_segcore
tbb utils pthread knowhere log milvus_proto
tbb milvus_utils pthread knowhere log milvus_proto
dl backtrace
milvus_common
milvus_query

View File

@ -19,6 +19,9 @@ NewCollection(const char* collection_proto) {
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
// TODO: delete print
std::cout << "create collection " << collection->get_collection_name() << std::endl;
return (void*)collection.release();
}

View File

@ -50,97 +50,187 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
}
struct SearchResultPair {
uint64_t id_;
float distance_;
int64_t segment_id_;
SearchResult* search_result_;
int64_t offset_;
int64_t index_;
SearchResultPair(uint64_t id, float distance, int64_t segment_id)
: id_(id), distance_(distance), segment_id_(segment_id) {
SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index)
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
}
bool
operator<(const SearchResultPair& pair) const {
return (distance_ < pair.distance_);
}
void
reset_distance() {
distance_ = search_result_->result_distances_[offset_];
}
};
void
GetResultData(std::vector<SearchResult*>& search_results,
SearchResult& final_result,
GetResultData(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
int64_t query_offset,
bool* is_selected,
int64_t topk) {
auto num_segments = search_results.size();
std::map<int, int> iter_loc_peer_result;
AssertInfo(num_segments > 0, "num segment must greater than 0");
std::vector<SearchResultPair> result_pairs;
for (int j = 0; j < num_segments; ++j) {
auto id = search_results[j]->result_ids_[query_offset];
auto distance = search_results[j]->result_distances_[query_offset];
result_pairs.push_back(SearchResultPair(id, distance, j));
iter_loc_peer_result[j] = query_offset;
auto search_result = search_results[j];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
}
std::sort(result_pairs.begin(), result_pairs.end());
final_result.result_ids_.push_back(result_pairs[0].id_);
final_result.result_distances_.push_back(result_pairs[0].distance_);
for (int i = 1; i < topk; ++i) {
auto segment_id = result_pairs[0].segment_id_;
auto query_offset = ++(iter_loc_peer_result[segment_id]);
auto id = search_results[segment_id]->result_ids_[query_offset];
auto distance = search_results[segment_id]->result_distances_[query_offset];
result_pairs[0] = SearchResultPair(id, distance, segment_id);
int64_t loc_offset = query_offset;
AssertInfo(topk > 0, "topK must greater than 0");
for (int i = 0; i < topk; ++i) {
result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end());
final_result.result_ids_.push_back(result_pairs[0].id_);
final_result.result_distances_.push_back(result_pairs[0].distance_);
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
is_selected[index] = true;
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
search_records[index].push_back(result_pair.offset_++);
}
}
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) {
void
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
std::vector<SearchResult*>& search_results,
bool* is_selected) {
auto num_segments = search_results.size();
AssertInfo(num_segments > 0, "num segment must greater than 0");
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
}
auto search_result = search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
std::vector<float> result_distances;
std::vector<int64_t> internal_seg_offsets;
std::vector<int64_t> result_ids;
for (int j = 0; j < search_records[i].size(); j++) {
auto& offset = search_records[i][j];
auto distance = search_result->result_distances_[offset];
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
auto id = search_result->result_ids_[offset];
result_distances.push_back(distance);
internal_seg_offsets.push_back(internal_seg_offset);
result_ids.push_back(id);
}
search_result->result_distances_ = result_distances;
search_result->internal_seg_offsets_ = internal_seg_offsets;
search_result->result_ids_ = result_ids;
}
}
CStatus
ReduceQueryResults(CQueryResult* c_search_results, int64_t num_segments, bool* is_selected) {
std::vector<SearchResult*> search_results;
for (int i = 0; i < num_segments; ++i) {
search_results.push_back((SearchResult*)query_results[i]);
search_results.push_back((SearchResult*)c_search_results[i]);
}
auto topk = search_results[0]->topK_;
auto num_queries = search_results[0]->num_queries_;
auto final_result = std::make_unique<SearchResult>();
try {
auto topk = search_results[0]->topK_;
auto num_queries = search_results[0]->num_queries_;
std::vector<std::vector<int64_t>> search_records(num_segments);
int64_t query_offset = 0;
for (int j = 0; j < num_queries; ++j) {
GetResultData(search_results, *final_result, query_offset, topk);
query_offset += topk;
int64_t query_offset = 0;
for (int j = 0; j < num_queries; ++j) {
GetResultData(search_records, search_results, query_offset, is_selected, topk);
query_offset += topk;
}
ResetSearchResult(search_records, search_results, is_selected);
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
return status;
}
return (CQueryResult)final_result.release();
}
CMarshaledHits
ReorganizeQueryResults(CQueryResult c_query_result,
CPlan c_plan,
CStatus
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups) {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto search_result = (milvus::engine::QueryResult*)c_query_result;
auto& result_ids = search_result->result_ids_;
auto& result_distances = search_result->result_distances_;
auto topk = GetTopK(c_plan);
int64_t queries_offset = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries; j++) {
auto index = topk * queries_offset++;
milvus::proto::service::Hits hits;
for (int k = index; k < index + topk; k++) {
hits.add_ids(result_ids[k]);
hits.add_scores(result_distances[k]);
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
hits_peer_group.blob_length_.push_back(blob.size());
int64_t num_groups,
CQueryResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CPlan c_plan) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
auto topk = GetTopK(c_plan);
std::vector<int64_t> num_queries_peer_group;
int64_t total_num_queries = 0;
for (int i = 0; i < num_groups; i++) {
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
num_queries_peer_group.push_back(num_queries);
total_num_queries += num_queries;
}
}
return (CMarshaledHits)marshaledHits.release();
std::vector<float> result_distances(total_num_queries * topk);
std::vector<int64_t> result_ids(total_num_queries * topk);
std::vector<std::vector<char>> row_datas(total_num_queries * topk);
int64_t count = 0;
for (int i = 0; i < num_segments; i++) {
if (is_selected[i] == false) {
continue;
}
auto search_result = (SearchResult*)c_search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto size = search_result->result_offsets_.size();
for (int j = 0; j < size; j++) {
auto loc = search_result->result_offsets_[j];
result_distances[loc] = search_result->result_distances_[j];
row_datas[loc] = search_result->row_data_[j];
result_ids[loc] = search_result->result_ids_[j];
}
count += size;
}
AssertInfo(count == total_num_queries * topk, "the reduced result's size is less than total_num_queries*topk");
int64_t fill_hit_offset = 0;
for (int i = 0; i < num_groups; i++) {
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
for (int j = 0; j < num_queries_peer_group[i]; j++) {
milvus::proto::service::Hits hits;
for (int k = 0; k < topk; k++, fill_hit_offset++) {
hits.add_ids(result_ids[fill_hit_offset]);
hits.add_scores(result_distances[fill_hit_offset]);
auto& row_data = row_datas[fill_hit_offset];
hits.add_row_data(row_data.data(), row_data.size());
}
auto blob = hits.SerializeAsString();
hits_peer_group.hits_.push_back(blob);
hits_peer_group.blob_length_.push_back(blob.size());
}
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
*c_marshaled_hits = nullptr;
return status;
}
}
int64_t

View File

@ -25,14 +25,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments);
CStatus
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments, bool* is_selected);
CMarshaledHits
ReorganizeQueryResults(CQueryResult query_result,
CPlan c_plan,
CStatus
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups);
int64_t num_groups,
CQueryResult* c_search_results,
bool* is_selected,
int64_t num_segments,
CPlan c_plan);
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);

View File

@ -155,6 +155,24 @@ Search(CSegmentBase c_segment,
return status;
}
CStatus
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
auto result = (milvus::engine::QueryResult*)c_result;
auto status = CStatus();
try {
auto res = segment->FillTargetEntry(plan, *result);
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
return status;
}
//////////////////////////////////////////////////////////////////
int

View File

@ -61,6 +61,9 @@ Search(CSegmentBase c_segment,
int num_groups,
CQueryResult* result);
CStatus
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult result);
//////////////////////////////////////////////////////////////////
int

View File

@ -17,8 +17,8 @@ set(UTILS_FILES
EasyAssert.cpp
)
add_library( utils STATIC ${UTILS_FILES} )
add_library( milvus_utils STATIC ${UTILS_FILES} )
target_link_libraries(utils
target_link_libraries(milvus_utils
libboost_filesystem.a
libboost_system.a)

View File

@ -11,11 +11,20 @@
#include <iostream>
#include "EasyAssert.h"
// #define BOOST_STACKTRACE_USE_ADDR2LINE
#define BOOST_STACKTRACE_USE_BACKTRACE
#include <boost/stacktrace.hpp>
#include <sstream>
namespace milvus::impl {
std::string
EasyStackTrace() {
auto stack_info = boost::stacktrace::stacktrace();
std::ostringstream ss;
ss << stack_info;
return ss.str();
}
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) {
@ -26,11 +35,15 @@ EasyAssertInfo(
if (!extra_info.empty()) {
info += " => " + std::string(extra_info);
}
auto fuck = boost::stacktrace::stacktrace();
std::cout << fuck;
// std::string s = fuck;
// info += ;
throw std::runtime_error(info);
throw std::runtime_error(info + "\n" + EasyStackTrace());
}
}
[[noreturn]] void
ThrowWithTrace(const std::exception& exception) {
auto err_msg = exception.what() + std::string("\n") + EasyStackTrace();
throw std::runtime_error(err_msg);
}
} // namespace milvus::impl

View File

@ -11,6 +11,7 @@
#pragma once
#include <string_view>
#include <exception>
#include <stdio.h>
#include <stdlib.h>
@ -20,7 +21,11 @@ namespace milvus::impl {
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
}
[[noreturn]] void
ThrowWithTrace(const std::exception& exception);
} // namespace milvus::impl
#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
#define Assert(expr) AssertInfo((expr), "")

View File

@ -11,6 +11,13 @@
#pragma once
#include "utils/EasyAssert.h"
#define JSON_ASSERT(expr) Assert((expr))
// TODO: dispatch error by type
#define JSON_THROW_USER(e) milvus::impl::ThrowWithTrace((e))
#include "nlohmann/json.hpp"
namespace milvus {

View File

@ -24,5 +24,6 @@ target_link_libraries(all_tests
knowhere
log
pthread
milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)

View File

@ -641,8 +641,14 @@ TEST(CApiTest, Reduce) {
results.push_back(res1);
results.push_back(res2);
auto reduced_search_result = ReduceQueryResults(results.data(), 2);
auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1);
bool is_selected[1] = {false};
status = ReduceQueryResults(results.data(), 1, is_selected);
assert(status.error_code == Success);
FillTargetEntry(segment, plan, res1);
void* reorganize_search_result = nullptr;
status = ReorganizeQueryResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), is_selected,
1, plan);
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
std::vector<char> hits_blob;
@ -660,7 +666,6 @@ TEST(CApiTest, Reduce) {
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(res1);
DeleteQueryResult(res2);
DeleteQueryResult(reduced_search_result);
DeleteMarshaledHits(reorganize_search_result);
DeleteCollection(collection);
DeleteSegment(segment);

View File

@ -107,6 +107,83 @@ TEST(Expr, Range) {
std::cout << out.dump(4);
}
TEST(Expr, InvalidRange) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"range": {
"age": {
"GT": 1,
"LT": "100"
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
TEST(Expr, InvalidDSL) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string dsl_string = R"(
{
"float": {
"must": [
{
"range": {
"age": {
"GT": 1,
"LT": 100
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
TEST(Expr, ShowExecutor) {
using namespace milvus::query;
using namespace milvus::segcore;

View File

@ -1,4 +1,4 @@
package kv
package etcdkv
import (
"context"

View File

@ -1,11 +1,11 @@
package kv_test
package etcdkv_test
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
"go.etcd.io/etcd/clientv3"
)
@ -28,7 +28,7 @@ func TestEtcdKV_Load(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
rootPath := "/etcd/test/root"
etcdKV := kv.NewEtcdKV(cli, rootPath)
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
defer etcdKV.Close()
defer etcdKV.RemoveWithPrefix("")
@ -86,7 +86,7 @@ func TestEtcdKV_MultiSave(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
rootPath := "/etcd/test/root"
etcdKV := kv.NewEtcdKV(cli, rootPath)
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
defer etcdKV.Close()
defer etcdKV.RemoveWithPrefix("")
@ -117,7 +117,7 @@ func TestEtcdKV_Remove(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
rootPath := "/etcd/test/root"
etcdKV := kv.NewEtcdKV(cli, rootPath)
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
defer etcdKV.Close()
defer etcdKV.RemoveWithPrefix("")
@ -188,7 +188,7 @@ func TestEtcdKV_MultiSaveAndRemove(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
rootPath := "/etcd/test/root"
etcdKV := kv.NewEtcdKV(cli, rootPath)
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
defer etcdKV.Close()
defer etcdKV.RemoveWithPrefix("")

View File

@ -8,6 +8,11 @@ type Base interface {
MultiSave(kvs map[string]string) error
Remove(key string) error
MultiRemove(keys []string) error
MultiSaveAndRemove(saves map[string]string, removals []string) error
Close()
}
type TxnBase interface {
Base
MultiSaveAndRemove(saves map[string]string, removals []string) error
}
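
Splitting the interface is what lets the new MinIO-backed store plug in wherever only plain key-value operations are needed: MinIOKV implements the Base operations only, while backends such as EtcdKV keep the transactional MultiSaveAndRemove behind TxnBase. A minimal sketch of what this buys at compile time, assuming both types implement every method of their respective interface (import paths taken from the packages touched in this commit):

```go
package main

import (
	"github.com/zilliztech/milvus-distributed/internal/kv"
	etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
)

// Compile-time assertions: the object store only has to satisfy Base, while
// components that need transactional semantics (e.g. the master's metaTable)
// ask for TxnBase, which the etcd-backed store provides.
var (
	_ kv.Base    = (*miniokv.MinIOKV)(nil)
	_ kv.TxnBase = (*etcdkv.EtcdKV)(nil)
)

func main() {}
```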

View File

@ -1,4 +1,4 @@
package kv
package memkv
import (
"sync"

View File

@ -0,0 +1,149 @@
package miniokv
import (
"context"
"io"
"log"
"strings"
"github.com/minio/minio-go/v7"
)
type MinIOKV struct {
ctx context.Context
minioClient *minio.Client
bucketName string
}
// NewMinIOKV creates a new MinIO kv.
func NewMinIOKV(ctx context.Context, client *minio.Client, bucketName string) (*MinIOKV, error) {
bucketExists, err := client.BucketExists(ctx, bucketName)
if err != nil {
return nil, err
}
if !bucketExists {
err = client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
if err != nil {
return nil, err
}
}
return &MinIOKV{
ctx: ctx,
minioClient: client,
bucketName: bucketName,
}, nil
}
func (kv *MinIOKV) LoadWithPrefix(key string) ([]string, []string, error) {
objects := kv.minioClient.ListObjects(kv.ctx, kv.bucketName, minio.ListObjectsOptions{Prefix: key})
var objectsKeys []string
var objectsValues []string
for object := range objects {
objectsKeys = append(objectsKeys, object.Key)
}
objectsValues, err := kv.MultiLoad(objectsKeys)
if err != nil {
log.Printf("cannot load value with prefix:%s", key)
}
return objectsKeys, objectsValues, nil
}
func (kv *MinIOKV) Load(key string) (string, error) {
object, err := kv.minioClient.GetObject(kv.ctx, kv.bucketName, key, minio.GetObjectOptions{})
if err != nil {
return "", err
}
buf := new(strings.Builder)
_, err = io.Copy(buf, object)
if err != nil && err != io.EOF {
return "", err
}
return buf.String(), nil
}
func (kv *MinIOKV) MultiLoad(keys []string) ([]string, error) {
var resultErr error
var objectsValues []string
for _, key := range keys {
objectValue, err := kv.Load(key)
if err != nil {
if resultErr == nil {
resultErr = err
}
}
objectsValues = append(objectsValues, objectValue)
}
return objectsValues, resultErr
}
func (kv *MinIOKV) Save(key, value string) error {
reader := strings.NewReader(value)
_, err := kv.minioClient.PutObject(kv.ctx, kv.bucketName, key, reader, int64(len(value)), minio.PutObjectOptions{})
if err != nil {
return err
}
return err
}
func (kv *MinIOKV) MultiSave(kvs map[string]string) error {
var resultErr error
for key, value := range kvs {
err := kv.Save(key, value)
if err != nil {
if resultErr == nil {
resultErr = err
}
}
}
return resultErr
}
func (kv *MinIOKV) RemoveWithPrefix(prefix string) error {
objectsCh := make(chan minio.ObjectInfo)
go func() {
defer close(objectsCh)
for object := range kv.minioClient.ListObjects(kv.ctx, kv.bucketName, minio.ListObjectsOptions{Prefix: prefix}) {
objectsCh <- object
}
}()
for rErr := range kv.minioClient.RemoveObjects(kv.ctx, kv.bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
if rErr.Err != nil {
return rErr.Err
}
}
return nil
}
func (kv *MinIOKV) Remove(key string) error {
err := kv.minioClient.RemoveObject(kv.ctx, kv.bucketName, string(key), minio.RemoveObjectOptions{})
return err
}
func (kv *MinIOKV) MultiRemove(keys []string) error {
var resultErr error
for _, key := range keys {
err := kv.Remove(key)
if err != nil {
if resultErr == nil {
resultErr = err
}
}
}
return resultErr
}
func (kv *MinIOKV) Close() {
}

View File

@ -0,0 +1,195 @@
package miniokv_test
import (
"context"
"strconv"
"testing"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
"github.com/stretchr/testify/assert"
)
var Params paramtable.BaseTable
func TestMinIOKV_Load(t *testing.T) {
Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
err = MinIOKV.Save("abc", "123")
assert.Nil(t, err)
err = MinIOKV.Save("abcd", "1234")
assert.Nil(t, err)
val, err := MinIOKV.Load("abc")
assert.Nil(t, err)
assert.Equal(t, val, "123")
keys, vals, err := MinIOKV.LoadWithPrefix("abc")
assert.Nil(t, err)
assert.Equal(t, len(keys), len(vals))
assert.Equal(t, len(keys), 2)
assert.Equal(t, vals[0], "123")
assert.Equal(t, vals[1], "1234")
err = MinIOKV.Save("key_1", "123")
assert.Nil(t, err)
err = MinIOKV.Save("key_2", "456")
assert.Nil(t, err)
err = MinIOKV.Save("key_3", "789")
assert.Nil(t, err)
keys = []string{"key_1", "key_100"}
vals, err = MinIOKV.MultiLoad(keys)
assert.NotNil(t, err)
assert.Equal(t, len(vals), len(keys))
assert.Equal(t, vals[0], "123")
assert.Empty(t, vals[1])
keys = []string{"key_1", "key_2"}
vals, err = MinIOKV.MultiLoad(keys)
assert.Nil(t, err)
assert.Equal(t, len(vals), len(keys))
assert.Equal(t, vals[0], "123")
assert.Equal(t, vals[1], "456")
}
func TestMinIOKV_MultiSave(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
err = MinIOKV.Save("key_1", "111")
assert.Nil(t, err)
kvs := map[string]string{
"key_1": "123",
"key_2": "456",
}
err = MinIOKV.MultiSave(kvs)
assert.Nil(t, err)
val, err := MinIOKV.Load("key_1")
assert.Nil(t, err)
assert.Equal(t, val, "123")
}
func TestMinIOKV_Remove(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
err = MinIOKV.Save("key_1", "123")
assert.Nil(t, err)
err = MinIOKV.Save("key_2", "456")
assert.Nil(t, err)
val, err := MinIOKV.Load("key_1")
assert.Nil(t, err)
assert.Equal(t, val, "123")
// delete "key_1"
err = MinIOKV.Remove("key_1")
assert.Nil(t, err)
val, err = MinIOKV.Load("key_1")
assert.Error(t, err)
assert.Empty(t, val)
val, err = MinIOKV.Load("key_2")
assert.Nil(t, err)
assert.Equal(t, val, "456")
keys, vals, err := MinIOKV.LoadWithPrefix("key")
assert.Nil(t, err)
assert.Equal(t, len(keys), len(vals))
assert.Equal(t, len(keys), 1)
assert.Equal(t, vals[0], "456")
// MultiRemove
err = MinIOKV.Save("key_1", "111")
assert.Nil(t, err)
kvs := map[string]string{
"key_1": "123",
"key_2": "456",
"key_3": "789",
"key_4": "012",
}
err = MinIOKV.MultiSave(kvs)
assert.Nil(t, err)
val, err = MinIOKV.Load("key_1")
assert.Nil(t, err)
assert.Equal(t, val, "123")
val, err = MinIOKV.Load("key_3")
assert.Nil(t, err)
assert.Equal(t, val, "789")
keys = []string{"key_1", "key_2", "key_3"}
err = MinIOKV.MultiRemove(keys)
assert.Nil(t, err)
val, err = MinIOKV.Load("key_1")
assert.Error(t, err)
assert.Empty(t, val)
}

View File

@ -1,14 +1,14 @@
package mockkv
import (
"github.com/zilliztech/milvus-distributed/internal/kv"
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
)
// use MemoryKV to mock EtcdKV
func NewEtcdKV() *kv.MemoryKV {
return kv.NewMemoryKV()
func NewEtcdKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}
func NewMemoryKV() *kv.MemoryKV {
return kv.NewMemoryKV()
func NewMemoryKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}

View File

@ -4,13 +4,13 @@ import (
"log"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
type getSysConfigsTask struct {
baseTask
configkv *kv.EtcdKV
configkv *etcdkv.EtcdKV
req *internalpb.SysConfigRequest
keys []string
values []string

View File

@ -36,7 +36,7 @@ type GlobalTSOAllocator struct {
}
// NewGlobalTSOAllocator creates a new global TSO allocator.
func NewGlobalTSOAllocator(key string, kvBase kv.Base) *GlobalTSOAllocator {
func NewGlobalTSOAllocator(key string, kvBase kv.TxnBase) *GlobalTSOAllocator {
var saveInterval = 3 * time.Second
return &GlobalTSOAllocator{
tso: &timestampOracle{

View File

@ -271,7 +271,7 @@ func (s *Master) HasPartition(ctx context.Context, in *internalpb.HasPartitionRe
return &servicepb.BoolResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "WaitToFinish failed",
Reason: err.Error(),
},
Value: t.(*hasPartitionTask).hasPartition,
}, nil

View File

@ -9,7 +9,7 @@ type GlobalIDAllocator struct {
allocator Allocator
}
func NewGlobalIDAllocator(key string, base kv.Base) *GlobalIDAllocator {
func NewGlobalIDAllocator(key string, base kv.TxnBase) *GlobalIDAllocator {
return &GlobalIDAllocator{
allocator: NewGlobalTSOAllocator(key, base),
}

View File

@ -10,7 +10,7 @@ import (
"sync/atomic"
"time"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
@ -42,7 +42,7 @@ type Master struct {
grpcServer *grpc.Server
grpcErr chan error
kvBase *kv.EtcdKV
kvBase *etcdkv.EtcdKV
scheduler *ddRequestScheduler
metaTable *metaTable
timesSyncMsgProducer *timeSyncMsgProducer
@ -63,12 +63,12 @@ type Master struct {
tsoAllocator *GlobalTSOAllocator
}
func newKVBase(kvRoot string, etcdAddr []string) *kv.EtcdKV {
func newKVBase(kvRoot string, etcdAddr []string) *etcdkv.EtcdKV {
cli, _ := clientv3.New(clientv3.Config{
Endpoints: etcdAddr,
DialTimeout: 5 * time.Second,
})
kvBase := kv.NewEtcdKV(cli, kvRoot)
kvBase := etcdkv.NewEtcdKV(cli, kvRoot)
return kvBase
}
@ -89,8 +89,8 @@ func CreateServer(ctx context.Context) (*Master, error) {
if err != nil {
return nil, err
}
etcdkv := kv.NewEtcdKV(etcdClient, metaRootPath)
metakv, err := NewMetaTable(etcdkv)
etcdKV := etcdkv.NewEtcdKV(etcdClient, metaRootPath)
metakv, err := NewMetaTable(etcdKV)
if err != nil {
return nil, err
}

View File

@ -11,7 +11,7 @@ import (
)
type metaTable struct {
client *kv.EtcdKV // client of a reliable kv service, i.e. etcd client
client kv.TxnBase // client of a reliable kv service, i.e. etcd client
tenantID2Meta map[UniqueID]pb.TenantMeta // tenant id to tenant meta
proxyID2Meta map[UniqueID]pb.ProxyMeta // proxy id to proxy meta
collID2Meta map[UniqueID]pb.CollectionMeta // collection id to collection meta
@ -23,7 +23,7 @@ type metaTable struct {
ddLock sync.RWMutex
}
func NewMetaTable(kv *kv.EtcdKV) (*metaTable, error) {
func NewMetaTable(kv kv.TxnBase) (*metaTable, error) {
mt := &metaTable{
client: kv,
tenantLock: sync.RWMutex{},

View File

@ -7,7 +7,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
pb "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"go.etcd.io/etcd/clientv3"
@ -19,7 +19,7 @@ func TestMetaTable_Collection(t *testing.T) {
etcdAddr := Params.EtcdAddress
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
assert.Nil(t, err)
@ -157,7 +157,7 @@ func TestMetaTable_DeletePartition(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
assert.Nil(t, err)
@ -252,7 +252,7 @@ func TestMetaTable_Segment(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
assert.Nil(t, err)
@ -333,7 +333,7 @@ func TestMetaTable_UpdateSegment(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
assert.Nil(t, err)
@ -379,7 +379,7 @@ func TestMetaTable_AddPartition_Limit(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
assert.Nil(t, err)

View File

@ -9,7 +9,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
@ -31,7 +31,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
meta, err := NewMetaTable(etcdKV)
assert.Nil(t, err)
@ -168,7 +168,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
assert.Nil(t, err)
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
meta, err := NewMetaTable(etcdKV)
assert.Nil(t, err)

View File

@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
pb "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
@ -29,7 +30,7 @@ var segMgr *SegmentManager
var collName = "coll_segmgr_test"
var collID = int64(1001)
var partitionTag = "test"
var kvBase *kv.EtcdKV
var kvBase kv.TxnBase
var master *Master
var masterCancelFunc context.CancelFunc
@ -48,7 +49,7 @@ func setup() {
if err != nil {
panic(err)
}
kvBase = kv.NewEtcdKV(cli, rootPath)
kvBase = etcdkv.NewEtcdKV(cli, rootPath)
tmpMt, err := NewMetaTable(kvBase)
if err != nil {
panic(err)

View File

@ -10,11 +10,11 @@ import (
"github.com/spf13/viper"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
)
type SysConfig struct {
kv *kv.EtcdKV
kv *etcdkv.EtcdKV
}
// Initialize Configs from config files, and store them in Etcd.

View File

@ -9,7 +9,7 @@ import (
"testing"
"time"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
@ -31,7 +31,7 @@ func Test_SysConfig(t *testing.T) {
require.Nil(t, err)
rootPath := "/test/root"
configKV := kv.NewEtcdKV(cli, rootPath)
configKV := etcdkv.NewEtcdKV(cli, rootPath)
defer configKV.Close()
sc := SysConfig{kv: configKV}

View File

@ -47,7 +47,7 @@ type atomicObject struct {
// timestampOracle is used to maintain the logic of tso.
type timestampOracle struct {
key string
kvBase kv.Base
kvBase kv.TxnBase
// TODO: remove saveInterval
saveInterval time.Duration

View File

@ -5,6 +5,7 @@ import (
"log"
"reflect"
"sync"
"time"
"github.com/apache/pulsar-client-go/pulsar"
"github.com/golang/protobuf/proto"
@ -69,11 +70,22 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) {
func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) {
for i := 0; i < len(channels); i++ {
pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
if err != nil {
log.Printf("Failed to create querynode producer %s, error = %v", channels[i], err)
fn := func() error {
pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
if err != nil {
return err
}
if pp == nil {
return errors.New("pulsar is not ready, producer is nil")
}
ms.producers = append(ms.producers, &pp)
return nil
}
err := Retry(10, time.Millisecond*200, fn)
if err != nil {
errMsg := "Failed to create producer " + channels[i] + ", error = " + err.Error()
panic(errMsg)
}
ms.producers = append(ms.producers, &pp)
}
}
@ -83,18 +95,29 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string,
pulsarBufSize int64) {
ms.unmarshal = unmarshal
for i := 0; i < len(channels); i++ {
receiveChannel := make(chan pulsar.ConsumerMessage, pulsarBufSize)
pc, err := (*ms.client).Subscribe(pulsar.ConsumerOptions{
Topic: channels[i],
SubscriptionName: subName,
Type: pulsar.KeyShared,
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
MessageChannel: receiveChannel,
})
if err != nil {
log.Printf("Failed to subscribe topic, error = %v", err)
fn := func() error {
receiveChannel := make(chan pulsar.ConsumerMessage, pulsarBufSize)
pc, err := (*ms.client).Subscribe(pulsar.ConsumerOptions{
Topic: channels[i],
SubscriptionName: subName,
Type: pulsar.KeyShared,
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
MessageChannel: receiveChannel,
})
if err != nil {
return err
}
if pc == nil {
return errors.New("pulsar is not ready, consumer is nil")
}
ms.consumers = append(ms.consumers, &pc)
return nil
}
err := Retry(10, time.Millisecond*200, fn)
if err != nil {
errMsg := "Failed to create consumer " + channels[i] + ", error = " + err.Error()
panic(errMsg)
}
ms.consumers = append(ms.consumers, &pc)
}
}

View File

@ -0,0 +1,32 @@
package msgstream
import (
"log"
"time"
)
// Reference: https://blog.cyeam.com/golang/2018/08/27/retry
func Retry(attempts int, sleep time.Duration, fn func() error) error {
if err := fn(); err != nil {
if s, ok := err.(InterruptError); ok {
return s.error
}
if attempts--; attempts > 0 {
log.Printf("retry func error: %s. attempts #%d after %s.", err.Error(), attempts, sleep)
time.Sleep(sleep)
return Retry(attempts, 2*sleep, fn)
}
return err
}
return nil
}
type InterruptError struct {
error
}
func NoRetryError(err error) InterruptError {
return InterruptError{err}
}
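
For illustration, a hypothetical caller of this helper could look like the sketch below: Retry doubles the sleep between attempts (100 ms, 200 ms, 400 ms, ...), and wrapping an error in NoRetryError aborts immediately instead of consuming the remaining attempts. The dial target and function name are made up for the example.

```go
package main

import (
	"errors"
	"net"
	"time"

	"github.com/zilliztech/milvus-distributed/internal/msgstream"
)

// connectWithRetry is a hypothetical caller: transient dial failures are
// retried with doubling backoff, while an empty address is declared fatal
// via NoRetryError and returned without further attempts.
func connectWithRetry(addr string) error {
	return msgstream.Retry(5, 100*time.Millisecond, func() error {
		if addr == "" {
			return msgstream.NoRetryError(errors.New("empty address"))
		}
		conn, err := net.DialTimeout("tcp", addr, time.Second)
		if err != nil {
			return err // Retry sleeps, doubles the interval, and calls again
		}
		return conn.Close()
	})
}

func main() { _ = connectWithRetry("localhost:6650") }
```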

View File

@ -18,6 +18,7 @@ const (
)
func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) {
log.Println("insert into: ", in.CollectionName)
it := &InsertTask{
Condition: NewTaskCondition(ctx),
BaseInsertTask: BaseInsertTask{
@ -76,6 +77,7 @@ func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.
}
func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSchema) (*commonpb.Status, error) {
log.Println("create collection: ", req)
cct := &CreateCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: internalpb.CreateCollectionRequest{
@ -117,6 +119,7 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc
}
func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.QueryResult, error) {
log.Println("search: ", req.CollectionName, req.Dsl)
qt := &QueryTask{
Condition: NewTaskCondition(ctx),
SearchRequest: internalpb.SearchRequest{
@ -164,6 +167,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu
}
func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionName) (*commonpb.Status, error) {
log.Println("drop collection: ", req)
dct := &DropCollectionTask{
Condition: NewTaskCondition(ctx),
DropCollectionRequest: internalpb.DropCollectionRequest{
@ -204,6 +208,7 @@ func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionNam
}
func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.BoolResponse, error) {
log.Println("has collection: ", req)
hct := &HasCollectionTask{
Condition: NewTaskCondition(ctx),
HasCollectionRequest: internalpb.HasCollectionRequest{
@ -248,6 +253,7 @@ func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName
}
func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.CollectionDescription, error) {
log.Println("describe collection: ", req)
dct := &DescribeCollectionTask{
Condition: NewTaskCondition(ctx),
DescribeCollectionRequest: internalpb.DescribeCollectionRequest{
@ -292,6 +298,7 @@ func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.Collectio
}
func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*servicepb.StringListResponse, error) {
log.Println("show collections")
sct := &ShowCollectionsTask{
Condition: NewTaskCondition(ctx),
ShowCollectionRequest: internalpb.ShowCollectionRequest{
@ -335,6 +342,7 @@ func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*serv
}
func (p *Proxy) CreatePartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
log.Println("create partition", in)
cpt := &CreatePartitionTask{
Condition: NewTaskCondition(ctx),
CreatePartitionRequest: internalpb.CreatePartitionRequest{
@ -380,6 +388,7 @@ func (p *Proxy) CreatePartition(ctx context.Context, in *servicepb.PartitionName
}
func (p *Proxy) DropPartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
log.Println("drop partition: ", in)
dpt := &DropPartitionTask{
Condition: NewTaskCondition(ctx),
DropPartitionRequest: internalpb.DropPartitionRequest{
@ -426,6 +435,7 @@ func (p *Proxy) DropPartition(ctx context.Context, in *servicepb.PartitionName)
}
func (p *Proxy) HasPartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.BoolResponse, error) {
log.Println("has partition: ", in)
hpt := &HasPartitionTask{
Condition: NewTaskCondition(ctx),
HasPartitionRequest: internalpb.HasPartitionRequest{
@ -478,6 +488,7 @@ func (p *Proxy) HasPartition(ctx context.Context, in *servicepb.PartitionName) (
}
func (p *Proxy) DescribePartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.PartitionDescription, error) {
log.Println("describe partition: ", in)
dpt := &DescribePartitionTask{
Condition: NewTaskCondition(ctx),
DescribePartitionRequest: internalpb.DescribePartitionRequest{
@ -532,6 +543,7 @@ func (p *Proxy) DescribePartition(ctx context.Context, in *servicepb.PartitionNa
}
func (p *Proxy) ShowPartitions(ctx context.Context, req *servicepb.CollectionName) (*servicepb.StringListResponse, error) {
log.Println("show partitions: ", req)
spt := &ShowPartitionsTask{
Condition: NewTaskCondition(ctx),
ShowPartitionRequest: internalpb.ShowPartitionRequest{

View File

@ -4,30 +4,26 @@ import (
"context"
"sync"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
type Cache interface {
Hit(collectionName string) bool
Get(collectionName string) (*servicepb.CollectionDescription, error)
Update(collectionName string) error
Sync(collectionName string) error
Update(collectionName string, desc *servicepb.CollectionDescription) error
Remove(collectionName string) error
}
var globalMetaCache Cache
type SimpleMetaCache struct {
mu sync.RWMutex
proxyID UniqueID
metas map[string]*servicepb.CollectionDescription // collection name to schema
masterClient masterpb.MasterClient
reqIDAllocator *allocator.IDAllocator
tsoAllocator *allocator.TimestampAllocator
ctx context.Context
mu sync.RWMutex
metas map[string]*servicepb.CollectionDescription // collection name to schema
ctx context.Context
proxyInstance *Proxy
}
func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
@ -47,58 +43,34 @@ func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.Collect
return schema, nil
}
func (metaCache *SimpleMetaCache) Update(collectionName string) error {
reqID, err := metaCache.reqIDAllocator.AllocOne()
if err != nil {
return err
}
ts, err := metaCache.tsoAllocator.AllocOne()
if err != nil {
return err
}
hasCollectionReq := &internalpb.HasCollectionRequest{
MsgType: internalpb.MsgType_kHasCollection,
ReqID: reqID,
Timestamp: ts,
ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
func (metaCache *SimpleMetaCache) Sync(collectionName string) error {
dct := &DescribeCollectionTask{
Condition: NewTaskCondition(metaCache.ctx),
DescribeCollectionRequest: internalpb.DescribeCollectionRequest{
MsgType: internalpb.MsgType_kDescribeCollection,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
},
masterClient: metaCache.proxyInstance.masterClient,
}
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
if err != nil {
return err
}
if !has.Value {
return errors.New("collection " + collectionName + " not exists")
}
var cancel func()
dct.ctx, cancel = context.WithTimeout(metaCache.ctx, reqTimeoutInterval)
defer cancel()
reqID, err = metaCache.reqIDAllocator.AllocOne()
if err != nil {
return err
}
ts, err = metaCache.tsoAllocator.AllocOne()
if err != nil {
return err
}
req := &internalpb.DescribeCollectionRequest{
MsgType: internalpb.MsgType_kDescribeCollection,
ReqID: reqID,
Timestamp: ts,
ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
}
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
err := metaCache.proxyInstance.sched.DdQueue.Enqueue(dct)
if err != nil {
return err
}
return dct.WaitToFinish()
}
func (metaCache *SimpleMetaCache) Update(collectionName string, desc *servicepb.CollectionDescription) error {
metaCache.mu.Lock()
defer metaCache.mu.Unlock()
metaCache.metas[collectionName] = resp
metaCache.metas[collectionName] = desc
return nil
}
@ -115,23 +87,14 @@ func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
return nil
}
func newSimpleMetaCache(ctx context.Context,
mCli masterpb.MasterClient,
idAllocator *allocator.IDAllocator,
tsoAllocator *allocator.TimestampAllocator) *SimpleMetaCache {
func newSimpleMetaCache(ctx context.Context, proxyInstance *Proxy) *SimpleMetaCache {
return &SimpleMetaCache{
metas: make(map[string]*servicepb.CollectionDescription),
masterClient: mCli,
reqIDAllocator: idAllocator,
tsoAllocator: tsoAllocator,
proxyID: Params.ProxyID(),
ctx: ctx,
metas: make(map[string]*servicepb.CollectionDescription),
proxyInstance: proxyInstance,
ctx: ctx,
}
}
func initGlobalMetaCache(ctx context.Context,
mCli masterpb.MasterClient,
idAllocator *allocator.IDAllocator,
tsoAllocator *allocator.TimestampAllocator) {
globalMetaCache = newSimpleMetaCache(ctx, mCli, idAllocator, tsoAllocator)
func initGlobalMetaCache(ctx context.Context, proxyInstance *Proxy) {
globalMetaCache = newSimpleMetaCache(ctx, proxyInstance)
}
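A minimal usage sketch of the reworked cache, assuming the Hit/Sync/Get flow above; the helper name below is illustrative and not part of this diff:

func describeWithCache(collectionName string) (*servicepb.CollectionDescription, error) {
    // Sync enqueues a DescribeCollectionTask on the proxy's DdQueue and blocks
    // until that task writes the description back through Update.
    if !globalMetaCache.Hit(collectionName) {
        if err := globalMetaCache.Sync(collectionName); err != nil {
            return nil, err
        }
    }
    return globalMetaCache.Get(collectionName)
}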

View File

@ -163,7 +163,7 @@ func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
panic(err)
}
var ret []int
for i := start; i <= end; i++ {
for i := start; i < end; i++ {
ret = append(ret, i)
}
return ret
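For example, assuming a "start,end" string such as "0,2", the returned slice is now [0, 1] rather than [0, 1, 2]; configured ranges are treated as half-open intervals [start, end).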

View File

@ -109,7 +109,7 @@ func (p *Proxy) startProxy() error {
if err != nil {
return err
}
initGlobalMetaCache(p.proxyLoopCtx, p.masterClient, p.idAllocator, p.tsoAllocator)
initGlobalMetaCache(p.proxyLoopCtx, p)
p.manipulationMsgStream.Start()
p.queryMsgStream.Start()
p.sched.Start()

View File

@ -119,7 +119,7 @@ func createCollection(t *testing.T, name string) {
Name: name,
Description: "no description",
AutoID: true,
Fields: make([]*schemapb.FieldSchema, 1),
Fields: make([]*schemapb.FieldSchema, 2),
}
fieldName := "Field1"
req.Fields[0] = &schemapb.FieldSchema{
@ -127,6 +127,24 @@ func createCollection(t *testing.T, name string) {
Description: "no description",
DataType: schemapb.DataType_INT32,
}
fieldName = "vec"
req.Fields[1] = &schemapb.FieldSchema{
Name: fieldName,
Description: "vector",
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "L2",
},
},
}
resp, err := proxyClient.CreateCollection(ctx, req)
assert.Nil(t, err)
msg := "Create Collection " + name + " should succeed!"
@ -139,7 +157,7 @@ func dropCollection(t *testing.T, name string) {
}
resp, err := proxyClient.DropCollection(ctx, req)
assert.Nil(t, err)
msg := "Drop Collection " + name + " should succeed!"
msg := "Drop Collection " + name + " should succeed! err :" + resp.Reason
assert.Equal(t, resp.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
}
@ -152,6 +170,7 @@ func TestProxy_CreateCollection(t *testing.T) {
go func(group *sync.WaitGroup) {
defer group.Done()
createCollection(t, collectionName)
dropCollection(t, collectionName)
}(&wg)
}
wg.Wait()
@ -165,9 +184,11 @@ func TestProxy_HasCollection(t *testing.T) {
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
createCollection(t, collectionName)
has := hasCollection(t, collectionName)
msg := "Should has Collection " + collectionName
assert.Equal(t, has, true, msg)
dropCollection(t, collectionName)
}(&wg)
}
wg.Wait()
@ -182,6 +203,7 @@ func TestProxy_DescribeCollection(t *testing.T) {
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
createCollection(t, collectionName)
has := hasCollection(t, collectionName)
if has {
resp, err := proxyClient.DescribeCollection(ctx, &servicepb.CollectionName{CollectionName: collectionName})
@ -191,6 +213,7 @@ func TestProxy_DescribeCollection(t *testing.T) {
msg := "Describe Collection " + strconv.Itoa(i) + " should succeed!"
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
t.Logf("Describe Collection %v: %v", i, resp)
dropCollection(t, collectionName)
}
}(&wg)
}
@ -206,6 +229,7 @@ func TestProxy_ShowCollections(t *testing.T) {
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
createCollection(t, collectionName)
has := hasCollection(t, collectionName)
if has {
resp, err := proxyClient.ShowCollections(ctx, &commonpb.Empty{})
@ -215,6 +239,7 @@ func TestProxy_ShowCollections(t *testing.T) {
msg := "Show collections " + strconv.Itoa(i) + " should succeed!"
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
t.Logf("Show collections %v: %v", i, resp)
dropCollection(t, collectionName)
}
}(&wg)
}
@ -246,6 +271,7 @@ func TestProxy_Insert(t *testing.T) {
}
msg := "Insert into Collection " + strconv.Itoa(i) + " should succeed!"
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
dropCollection(t, collectionName)
}
}(&wg)
}
@ -308,6 +334,7 @@ func TestProxy_Search(t *testing.T) {
queryWg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
//createCollection(t, collectionName)
has := hasCollection(t, collectionName)
if !has {
createCollection(t, collectionName)
@ -315,6 +342,7 @@ func TestProxy_Search(t *testing.T) {
resp, err := proxyClient.Search(ctx, req)
t.Logf("response of search collection %v: %v", i, resp)
assert.Nil(t, err)
dropCollection(t, collectionName)
}(&queryWg)
}
@ -328,9 +356,9 @@ func TestProxy_Search(t *testing.T) {
func TestProxy_AssignSegID(t *testing.T) {
collectionName := "CreateCollection1"
createCollection(t, collectionName)
testNum := 4
testNum := 1
for i := 0; i < testNum; i++ {
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, "default", int32(i), 200000)
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.defaultPartitionTag(), int32(i), 200000)
assert.Nil(t, err)
fmt.Println("segID", segID)
}
@ -345,6 +373,7 @@ func TestProxy_DropCollection(t *testing.T) {
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
createCollection(t, collectionName)
has := hasCollection(t, collectionName)
if has {
dropCollection(t, collectionName)
@ -357,27 +386,14 @@ func TestProxy_DropCollection(t *testing.T) {
func TestProxy_PartitionGRPC(t *testing.T) {
var wg sync.WaitGroup
collName := "collPartTest"
filedName := "collPartTestF1"
collReq := &schemapb.CollectionSchema{
Name: collName,
Fields: []*schemapb.FieldSchema{
&schemapb.FieldSchema{
Name: filedName,
Description: "",
DataType: schemapb.DataType_VECTOR_FLOAT,
},
},
}
st, err := proxyClient.CreateCollection(ctx, collReq)
assert.Nil(t, err)
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
createCollection(t, collName)
for i := 0; i < testNum; i++ {
wg.Add(1)
i := i
go func() {
defer wg.Done()
tag := fmt.Sprintf("partition-%d", i)
tag := fmt.Sprintf("partition_%d", i)
preq := &servicepb.PartitionName{
CollectionName: collName,
Tag: tag,
@ -413,6 +429,7 @@ func TestProxy_PartitionGRPC(t *testing.T) {
}()
}
wg.Wait()
dropCollection(t, collName)
}
func TestMain(m *testing.M) {

View File

@ -91,7 +91,7 @@ func (it *InsertTask) PreExecute() error {
func (it *InsertTask) Execute() error {
collectionName := it.BaseInsertTask.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Update(collectionName)
err := globalMetaCache.Sync(collectionName)
if err != nil {
return err
}
@ -103,17 +103,20 @@ func (it *InsertTask) Execute() error {
autoID := description.Schema.AutoID
var rowIDBegin UniqueID
var rowIDEnd UniqueID
if autoID || true {
rowNums := len(it.BaseInsertTask.RowData)
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
for i := rowIDBegin; i < rowIDEnd; i++ {
offset := i - rowIDBegin
it.BaseInsertTask.RowIDs[offset] = i
}
if autoID {
if it.HashValues == nil || len(it.HashValues) == 0 {
it.HashValues = make([]uint32, 0)
}
rowNums := len(it.BaseInsertTask.RowData)
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
for i := rowIDBegin; i < rowIDEnd; i++ {
offset := i - rowIDBegin
it.BaseInsertTask.RowIDs[offset] = i
hashValue, _ := typeutil.Hash32Int64(i)
for _, rowID := range it.RowIDs {
hashValue, _ := typeutil.Hash32Int64(rowID)
it.HashValues = append(it.HashValues, hashValue)
}
}
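Because the removed and added lines are interleaved above, here is a rough sketch of what the autoID branch now appears to do (illustrative only; the names are taken from the task struct above):

if autoID {
    if it.HashValues == nil || len(it.HashValues) == 0 {
        it.HashValues = make([]uint32, 0)
    }
    rowNums := len(it.BaseInsertTask.RowData)
    rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
    it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
    for i := rowIDBegin; i < rowIDEnd; i++ {
        it.BaseInsertTask.RowIDs[i-rowIDBegin] = i
    }
    // each freshly allocated row ID is hashed to pick the insert channel
    for _, rowID := range it.RowIDs {
        hashValue, _ := typeutil.Hash32Int64(rowID)
        it.HashValues = append(it.HashValues, hashValue)
    }
}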
@ -126,6 +129,7 @@ func (it *InsertTask) Execute() error {
}
msgPack.Msgs[0] = tsMsg
err = it.manipulationMsgStream.Produce(msgPack)
it.result = &servicepb.IntegerRangeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_SUCCESS,
@ -352,7 +356,7 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
func (qt *QueryTask) PreExecute() error {
collectionName := qt.query.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Update(collectionName)
err := globalMetaCache.Sync(collectionName)
if err != nil {
return err
}
@ -605,14 +609,9 @@ func (dct *DescribeCollectionTask) PreExecute() error {
}
func (dct *DescribeCollectionTask) Execute() error {
if !globalMetaCache.Hit(dct.CollectionName.CollectionName) {
err := globalMetaCache.Update(dct.CollectionName.CollectionName)
if err != nil {
return err
}
}
var err error
dct.result, err = globalMetaCache.Get(dct.CollectionName.CollectionName)
dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest)
globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
return err
}

View File

@ -27,7 +27,7 @@ func TestTimeTick_Start(t *testing.T) {
func TestTimeTick_Start2(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
masterAddr := Params.MasterAddress()
tsoAllocator, err := allocator.NewTimestampAllocator(ctx, masterAddr)
assert.Nil(t, err)

View File

@ -9,7 +9,6 @@ import (
)
func TestValidateCollectionName(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateCollectionName("abc"))
assert.Nil(t, ValidateCollectionName("_123abc"))
assert.Nil(t, ValidateCollectionName("abc123_$"))
@ -34,8 +33,8 @@ func TestValidateCollectionName(t *testing.T) {
}
func TestValidatePartitionTag(t *testing.T) {
Params.Init()
assert.Nil(t, ValidatePartitionTag("abc", true))
assert.Nil(t, ValidatePartitionTag("123abc", true))
assert.Nil(t, ValidatePartitionTag("_123abc", true))
assert.Nil(t, ValidatePartitionTag("abc123_$", true))
@ -44,7 +43,6 @@ func TestValidatePartitionTag(t *testing.T) {
longName[i] = 'a'
}
invalidNames := []string{
"123abc",
"$abc",
"_12 ac",
" ",
@ -62,7 +60,6 @@ func TestValidatePartitionTag(t *testing.T) {
}
func TestValidateFieldName(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateFieldName("abc"))
assert.Nil(t, ValidateFieldName("_123abc"))
@ -86,7 +83,6 @@ func TestValidateFieldName(t *testing.T) {
}
func TestValidateDimension(t *testing.T) {
Params.Init()
assert.Nil(t, ValidateDimension(1, false))
assert.Nil(t, ValidateDimension(Params.MaxDimension(), false))
assert.Nil(t, ValidateDimension(8, true))

View File

@ -12,6 +12,7 @@ package querynode
*/
import "C"
import (
"fmt"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"log"
"strconv"
@ -229,6 +230,7 @@ func (colReplica *collectionReplicaImpl) addPartitionsByCollectionMeta(colMeta *
if err != nil {
log.Println(err)
}
fmt.Println("add partition: ", tag)
}
return nil
@ -262,6 +264,7 @@ func (colReplica *collectionReplicaImpl) removePartitionsByCollectionMeta(colMet
if err != nil {
log.Println(err)
}
fmt.Println("delete partition: ", tag)
}
return nil

View File

@ -12,7 +12,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"go.etcd.io/etcd/clientv3"
"go.etcd.io/etcd/mvcc/mvccpb"
@ -25,7 +25,7 @@ const (
type metaService struct {
ctx context.Context
kvBase *kv.EtcdKV
kvBase *etcdkv.EtcdKV
replica *collectionReplica
}
@ -40,7 +40,7 @@ func newMetaService(ctx context.Context, replica *collectionReplica) *metaServic
return &metaService{
ctx: ctx,
kvBase: kv.NewEtcdKV(cli, MetaRootPath),
kvBase: etcdkv.NewEtcdKV(cli, MetaRootPath),
replica: replica,
}
}
@ -145,7 +145,7 @@ func printSegmentStruct(obj *etcdpb.SegmentMeta) {
}
func (mService *metaService) processCollectionCreate(id string, value string) {
println(fmt.Sprintf("Create Collection:$%s$", id))
//println(fmt.Sprintf("Create Collection:$%s$", id))
col := mService.collectionUnmarshal(value)
if col != nil {
@ -163,7 +163,7 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
}
func (mService *metaService) processSegmentCreate(id string, value string) {
println("Create Segment: ", id)
//println("Create Segment: ", id)
seg := mService.segmentUnmarshal(value)
if !isSegmentChannelRangeInQueryNodeChannelRange(seg) {
@ -182,7 +182,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
}
func (mService *metaService) processCreate(key string, msg string) {
println("process create", key)
//println("process create", key)
if isCollectionObj(key) {
objID := GetCollectionObjID(key)
mService.processCollectionCreate(objID, msg)
@ -214,7 +214,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
}
func (mService *metaService) processCollectionModify(id string, value string) {
println("Modify Collection: ", id)
//println("Modify Collection: ", id)
col := mService.collectionUnmarshal(value)
if col != nil {
@ -242,7 +242,7 @@ func (mService *metaService) processModify(key string, msg string) {
}
func (mService *metaService) processSegmentDelete(id string) {
println("Delete segment: ", id)
//println("Delete segment: ", id)
var segmentID, err = strconv.ParseInt(id, 10, 64)
if err != nil {
@ -257,7 +257,7 @@ func (mService *metaService) processSegmentDelete(id string) {
}
func (mService *metaService) processCollectionDelete(id string) {
println("Delete collection: ", id)
//println("Delete collection: ", id)
var collectionID, err = strconv.ParseInt(id, 10, 64)
if err != nil {
@ -272,7 +272,7 @@ func (mService *metaService) processCollectionDelete(id string) {
}
func (mService *metaService) processDelete(key string) {
println("process delete")
//println("process delete")
if isCollectionObj(key) {
objID := GetCollectionObjID(key)

View File

@ -29,7 +29,7 @@ func createPlan(col Collection, dsl string) (*Plan, error) {
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
return nil, errors.New("Create plan failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
var newPlan = &Plan{cPlan: cPlan}
@ -60,7 +60,7 @@ func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGro
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
return nil, errors.New("Parser placeholder group failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}

View File

@ -10,6 +10,8 @@ package querynode
*/
import "C"
import (
"errors"
"strconv"
"unsafe"
)
@ -21,26 +23,66 @@ type MarshaledHits struct {
cMarshaledHits C.CMarshaledHits
}
func reduceSearchResults(searchResults []*SearchResult, numSegments int64) *SearchResult {
func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
cSearchResults := make([]C.CQueryResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cQueryResult)
}
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
cNumSegments := C.long(numSegments)
res := C.ReduceQueryResults(cSearchResultPtr, cNumSegments)
return &SearchResult{cQueryResult: res}
cInReduced := (*C.bool)(&inReduced[0])
status := C.ReduceQueryResults(cSearchResultPtr, cNumSegments, cInReduced)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New("reduceSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return nil
}
func (sr *SearchResult) reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup) *MarshaledHits {
func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
for i, value := range inReduced {
if value {
err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
if err != nil {
return err
}
}
}
return nil
}
func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeholderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
cNumGroup := (C.long)(len(placeholderGroups))
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
res := C.ReorganizeQueryResults(sr.cQueryResult, plan.cPlan, cPlaceHolder, cNumGroup)
return &MarshaledHits{cMarshaledHits: res}
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroup = (C.long)(len(placeholderGroups))
cSearchResults := make([]C.CQueryResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cQueryResult)
}
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
var cNumSegments = C.long(numSegments)
var cInReduced = (*C.bool)(&inReduced[0])
var cMarshaledHits C.CMarshaledHits
status := C.ReorganizeQueryResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cPlan)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
}
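Taken together, the reworked wrappers are meant to be driven in three steps; a minimal sketch, assuming searchResults, matchedSegments, plan and placeholderGroups are already populated (the same sequence appears in the search service and the test below):

inReduced := make([]bool, len(searchResults))
numSegments := int64(len(searchResults))
if err := reduceSearchResults(searchResults, numSegments, inReduced); err != nil {
    return err
}
// fill target entries only for the results the reduce step flagged
if err := fillTargetEntry(plan, searchResults, matchedSegments, inReduced); err != nil {
    return err
}
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegments, inReduced)
if err != nil {
    return err
}
// consume marshaledHits, then release it with deleteMarshaledHits(marshaledHits)
// and the per-segment results with deleteSearchResults(searchResults)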
func (mh *MarshaledHits) getHitsBlobSize() int64 {

View File

@ -107,15 +107,21 @@ func TestReduce_AllFunc(t *testing.T) {
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
matchedSegment := make([]*Segment, 0)
searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{0})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
matchedSegment = append(matchedSegment, segment)
reducedSearchResults := reduceSearchResults(searchResults, 1)
assert.NotNil(t, reducedSearchResults)
testReduce := make([]bool, len(searchResults))
err = reduceSearchResults(searchResults, 1, testReduce)
assert.Nil(t, err)
err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
assert.Nil(t, err)
marshaledHits := reducedSearchResults.reorganizeQueryResults(plan, placeholderGroups)
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, 1, testReduce)
assert.NotNil(t, marshaledHits)
assert.Nil(t, err)
hitsBlob, err := marshaledHits.getHitsBlob()
assert.Nil(t, err)
@ -137,7 +143,6 @@ func TestReduce_AllFunc(t *testing.T) {
plan.delete()
holder.delete()
deleteSearchResults(searchResults)
deleteSearchResults([]*SearchResult{reducedSearchResults})
deleteMarshaledHits(marshaledHits)
deleteSegment(segment)
deleteCollection(collection)

View File

@ -139,7 +139,7 @@ func (ss *searchService) receiveSearchMsg() {
err := ss.search(msg)
if err != nil {
log.Println(err)
err = ss.publishFailedSearchResult(msg)
err = ss.publishFailedSearchResult(msg, err.Error())
if err != nil {
log.Println("publish FailedSearchResult failed, error message: ", err)
}
@ -191,7 +191,7 @@ func (ss *searchService) doUnsolvedMsgSearch() {
err := ss.search(msg)
if err != nil {
log.Println(err)
err = ss.publishFailedSearchResult(msg)
err = ss.publishFailedSearchResult(msg, err.Error())
if err != nil {
log.Println("publish FailedSearchResult failed, error message: ", err)
}
@ -238,6 +238,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
placeholderGroups = append(placeholderGroups, placeholderGroup)
searchResults := make([]*SearchResult, 0)
matchedSegments := make([]*Segment, 0)
for _, partitionTag := range partitionTags {
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
@ -257,6 +258,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return err
}
searchResults = append(searchResults, searchResult)
matchedSegments = append(matchedSegments, segment)
}
}
@ -282,8 +284,20 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
return nil
}
reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults)))
marshaledHits := reducedSearchResult.reorganizeQueryResults(plan, placeholderGroups)
inReduced := make([]bool, len(searchResults))
numSegment := int64(len(searchResults))
err = reduceSearchResults(searchResults, numSegment, inReduced)
if err != nil {
return err
}
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
if err != nil {
return err
}
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
if err != nil {
return err
}
hitsBlob, err := marshaledHits.getHitsBlob()
if err != nil {
return err
@ -291,12 +305,12 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
var offset int64 = 0
for index := range placeholderGroups {
hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
if err != nil {
return err
}
hits := make([][]byte, 0)
for _, len := range hitBolbSizePeerQuery {
for _, len := range hitBlobSizePeerQuery {
hits = append(hits, hitsBlob[offset:offset+len])
//test code to checkout marshaled hits
//marshaledHit := hitsBlob[offset:offset+len]
@ -329,7 +343,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
deleteSearchResults(searchResults)
deleteSearchResults([]*SearchResult{reducedSearchResult})
deleteMarshaledHits(marshaledHits)
plan.delete()
placeholderGroup.delete()
@ -346,7 +359,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
return nil
}
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error {
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
msgPack := msgstream.MsgPack{}
searchMsg, ok := msg.(*msgstream.SearchMsg)
if !ok {
@ -354,7 +367,7 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error {
}
var results = internalpb.SearchResult{
MsgType: internalpb.MsgType_kSearchResult,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: errMsg},
ReqID: searchMsg.ReqID,
ProxyID: searchMsg.ProxyID,
QueryNodeID: searchMsg.ProxyID,

View File

@ -253,3 +253,242 @@ func TestSearch_Search(t *testing.T) {
cancel()
node.Close()
}
func TestSearch_SearchMultiSegments(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 1024
const receiveBufSize = 1024
const DIM = 16
insertProducerChannels := Params.insertChannelNames()
searchProducerChannels := Params.searchChannelNames()
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
// start search service
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
var searchRawData1 []byte
var searchRawData2 []byte
for i, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
searchRawData1 = append(searchRawData1, buf...)
}
for i, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
searchRawData2 = append(searchRawData2, buf...)
}
placeholderValue := servicepb.PlaceholderValue{
Tag: "$0",
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
Values: [][]byte{searchRawData1, searchRawData2},
}
placeholderGroup := servicepb.PlaceholderGroup{
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
}
placeGroupByte, err := proto.Marshal(&placeholderGroup)
if err != nil {
log.Print("marshal placeholderGroup failed")
}
query := servicepb.Query{
CollectionName: "collection0",
PartitionTags: []string{"default"},
Dsl: dslString,
PlaceholderGroup: placeGroupByte,
}
queryByte, err := proto.Marshal(&query)
if err != nil {
log.Print("marshal query failed")
}
blob := commonpb.Blob{
Value: queryByte,
}
searchMsg := &msgstream.SearchMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
SearchRequest: internalpb.SearchRequest{
MsgType: internalpb.MsgType_kSearch,
ReqID: int64(1),
ProxyID: int64(1),
Timestamp: uint64(10 + 1000),
ResultChannelID: int64(0),
Query: &blob,
},
}
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.ctx, node.replica)
go node.searchService.start()
// start insert
timeRange := TimeRange{
timestampMin: 0,
timestampMax: math.MaxUint64,
}
insertMessages := make([]msgstream.TsMsg, 0)
for i := 0; i < msgLength; i++ {
segmentID := 0
if i >= msgLength/2 {
segmentID = 1
}
var rawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
rawData = append(rawData, buf...)
}
bs := make([]byte, 4)
binary.LittleEndian.PutUint32(bs, 1)
rawData = append(rawData, bs...)
var msg msgstream.TsMsg = &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{
uint32(i),
},
},
InsertRequest: internalpb.InsertRequest{
MsgType: internalpb.MsgType_kInsert,
ReqID: int64(i),
CollectionName: "collection0",
PartitionTag: "default",
SegmentID: int64(segmentID),
ChannelID: int64(0),
ProxyID: int64(0),
Timestamps: []uint64{uint64(i + 1000)},
RowIDs: []int64{int64(i)},
RowData: []*commonpb.Blob{
{Value: rawData},
},
},
}
insertMessages = append(insertMessages, msg)
}
msgPack := msgstream.MsgPack{
BeginTs: timeRange.timestampMin,
EndTs: timeRange.timestampMax,
Msgs: insertMessages,
}
// generate timeTick
timeTickMsgPack := msgstream.MsgPack{}
baseMsg := msgstream.BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{0},
}
timeTickResult := internalpb.TimeTickMsg{
MsgType: internalpb.MsgType_kTimeTick,
PeerID: UniqueID(0),
Timestamp: math.MaxUint64,
}
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: baseMsg,
TimeTickMsg: timeTickResult,
}
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
// pulsar produce
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
insertStream.Start()
err = insertStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
go node.dataSyncService.start()
time.Sleep(1 * time.Second)
cancel()
node.Close()
}

View File

@ -208,7 +208,7 @@ func (s *Segment) segmentSearch(plan *Plan,
var cTimestamp = (*C.ulong)(&timestamp[0])
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroups = C.int(len(placeHolderGroups))
cQueryResult := (*C.CQueryResult)(&searchResult.cQueryResult)
var cQueryResult = (*C.CQueryResult)(&searchResult.cQueryResult)
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cQueryResult)
errorCode := status.error_code
@ -221,3 +221,18 @@ func (s *Segment) segmentSearch(plan *Plan,
return &searchResult, nil
}
func (s *Segment) fillTargetEntry(plan *Plan,
result *SearchResult) error {
var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult)
errorCode := status.error_code
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
return nil
}

View File

@ -20,7 +20,7 @@ import (
"github.com/spf13/cast"
"github.com/spf13/viper"
"github.com/zilliztech/milvus-distributed/internal/kv"
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
)
type Base interface {
@ -33,16 +33,25 @@ type Base interface {
}
type BaseTable struct {
params *kv.MemoryKV
params *memkv.MemoryKV
}
func (gp *BaseTable) Init() {
gp.params = kv.NewMemoryKV()
gp.params = memkv.NewMemoryKV()
err := gp.LoadYaml("config.yaml")
if err != nil {
panic(err)
}
minioAddress := os.Getenv("MINIO_ADDRESS")
if minioAddress == "" {
minioAddress = "localhost:9000"
}
err = gp.Save("_MinioAddress", minioAddress)
if err != nil {
panic(err)
}
etcdAddress := os.Getenv("ETCD_ADDRESS")
if etcdAddress == "" {
etcdAddress = "localhost:2379"

View File

@ -4,7 +4,7 @@ import (
"path"
"time"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"go.etcd.io/etcd/clientv3"
)
@ -25,10 +25,10 @@ func ParseTS(ts uint64) (time.Time, uint64) {
return physicalTime, logical
}
func NewTSOKVBase(etcdAddr []string, tsoRoot, subPath string) *kv.EtcdKV {
func NewTSOKVBase(etcdAddr []string, tsoRoot, subPath string) *etcdkv.EtcdKV {
client, _ := clientv3.New(clientv3.Config{
Endpoints: etcdAddr,
DialTimeout: 5 * time.Second,
})
return kv.NewEtcdKV(client, path.Join(tsoRoot, subPath))
return etcdkv.NewEtcdKV(client, path.Join(tsoRoot, subPath))
}

View File

@ -13,5 +13,5 @@ SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
# ignore Minio,S3 unittests
MILVUS_DIR="${SCRIPTS_DIR}/../internal/"
echo $MILVUS_DIR
#go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/proxy/..." -failfast
go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast
go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/proxy/..." -failfast
#go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast

0
tests/python/__init__.py Normal file
View File

235
tests/python/conftest.py Normal file
View File

@ -0,0 +1,235 @@
import socket
import pytest
from .utils import *
timeout = 60
dimension = 128
delete_timeout = 60
def pytest_addoption(parser):
parser.addoption("--ip", action="store", default="localhost")
parser.addoption("--service", action="store", default="")
parser.addoption("--port", action="store", default=19530)
parser.addoption("--http-port", action="store", default=19121)
parser.addoption("--handler", action="store", default="GRPC")
parser.addoption("--tag", action="store", default="all", help="only run tests matching the tag.")
parser.addoption('--dry-run', action='store_true', default=False)
def pytest_configure(config):
# register an additional marker
config.addinivalue_line(
"markers", "tag(name): mark test to run only matching the tag"
)
def pytest_runtest_setup(item):
tags = list()
for marker in item.iter_markers(name="tag"):
for tag in marker.args:
tags.append(tag)
if tags:
cmd_tag = item.config.getoption("--tag")
if cmd_tag != "all" and cmd_tag not in tags:
pytest.skip("test requires tag in {!r}".format(tags))
def pytest_runtestloop(session):
if session.config.getoption('--dry-run'):
total_passed = 0
total_skipped = 0
test_file_to_items = {}
for item in session.items:
file_name, test_class, test_func = item.nodeid.split("::")
if test_file_to_items.get(file_name) is not None:
test_file_to_items[file_name].append(item)
else:
test_file_to_items[file_name] = [item]
for k, items in test_file_to_items.items():
skip_case = []
should_pass_but_skipped = []
skipped_other_reason = []
level2_case = []
for item in items:
if "pytestmark" in item.keywords.keys():
markers = item.keywords["pytestmark"]
skip_case.extend([item.nodeid for marker in markers if marker.name == 'skip'])
should_pass_but_skipped.extend([item.nodeid for marker in markers if marker.name == 'skip' and len(marker.args) > 0 and marker.args[0] == "should pass"])
skipped_other_reason.extend([item.nodeid for marker in markers if marker.name == 'skip' and (len(marker.args) < 1 or marker.args[0] != "should pass")])
level2_case.extend([item.nodeid for marker in markers if marker.name == 'level' and marker.args[0] == 2])
print("")
print(f"[{k}]:")
print(f" Total : {len(items):13}")
print(f" Passed : {len(items) - len(skip_case):13}")
print(f" Skipped : {len(skip_case):13}")
print(f" - should pass: {len(should_pass_but_skipped):4}")
print(f" - not supported: {len(skipped_other_reason):4}")
print(f" Level2 : {len(level2_case):13}")
print(f" ---------------------------------------")
print(f" should pass but skipped: ")
print("")
for nodeid in should_pass_but_skipped:
name, test_class, test_func = nodeid.split("::")
print(f" {name:8}: {test_class}.{test_func}")
print("")
print(f"===============================================")
total_passed += len(items) - len(skip_case)
total_skipped += len(skip_case)
print("Total tests : ", len(session.items))
print("Total passed: ", total_passed)
print("Total skiped: ", total_skipped)
return True
def check_server_connection(request):
ip = request.config.getoption("--ip")
port = request.config.getoption("--port")
connected = True
if ip and (ip not in ['localhost', '127.0.0.1']):
try:
socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
except Exception as e:
print("Socket connnet failed: %s" % str(e))
connected = False
return connected
@pytest.fixture(scope="module")
def connect(request):
ip = request.config.getoption("--ip")
service_name = request.config.getoption("--service")
port = request.config.getoption("--port")
http_port = request.config.getoption("--http-port")
handler = request.config.getoption("--handler")
if handler == "HTTP":
port = http_port
try:
milvus = get_milvus(host=ip, port=port, handler=handler)
# reset_build_index_threshold(milvus)
except Exception as e:
logging.getLogger().error(str(e))
pytest.exit("Milvus server can not connected, exit pytest ...")
def fin():
try:
milvus.close()
pass
except Exception as e:
logging.getLogger().info(str(e))
request.addfinalizer(fin)
return milvus
@pytest.fixture(scope="module")
def dis_connect(request):
ip = request.config.getoption("--ip")
service_name = request.config.getoption("--service")
port = request.config.getoption("--port")
http_port = request.config.getoption("--http-port")
handler = request.config.getoption("--handler")
if handler == "HTTP":
port = http_port
milvus = get_milvus(host=ip, port=port, handler=handler)
milvus.close()
return milvus
@pytest.fixture(scope="module")
def args(request):
ip = request.config.getoption("--ip")
service_name = request.config.getoption("--service")
port = request.config.getoption("--port")
http_port = request.config.getoption("--http-port")
handler = request.config.getoption("--handler")
if handler == "HTTP":
port = http_port
args = {"ip": ip, "port": port, "handler": handler, "service_name": service_name}
return args
@pytest.fixture(scope="module")
def milvus(request):
ip = request.config.getoption("--ip")
port = request.config.getoption("--port")
http_port = request.config.getoption("--http-port")
handler = request.config.getoption("--handler")
if handler == "HTTP":
port = http_port
return get_milvus(host=ip, port=port, handler=handler)
@pytest.fixture(scope="function")
def collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
default_fields = gen_default_fields()
connect.create_collection(collection_name, default_fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
if connect.has_collection(collection_name):
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
# customised id
@pytest.fixture(scope="function")
def id_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
fields = gen_default_fields(auto_id=False)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
if connect.has_collection(collection_name):
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
@pytest.fixture(scope="function")
def binary_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
fields = gen_binary_default_fields()
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
collection_names = connect.list_collections()
if connect.has_collection(collection_name):
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name
# customised id
@pytest.fixture(scope="function")
def binary_id_collection(request, connect):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
try:
fields = gen_binary_default_fields(auto_id=False)
connect.create_collection(collection_name, fields)
except Exception as e:
pytest.exit(str(e))
def teardown():
if connect.has_collection(collection_name):
connect.drop_collection(collection_name, timeout=delete_timeout)
request.addfinalizer(teardown)
assert connect.has_collection(collection_name)
return collection_name

22
tests/python/constants.py Normal file
View File

@ -0,0 +1,22 @@
from . import utils
default_fields = utils.gen_default_fields()
default_binary_fields = utils.gen_binary_default_fields()
default_entity = utils.gen_entities(1)
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)
default_entity_row = utils.gen_entities_rows(1)
default_raw_binary_vector_row, default_binary_entity_row = utils.gen_binary_entities_rows(1)
default_entities = utils.gen_entities(utils.default_nb)
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)
default_entities_new = utils.gen_entities_new(utils.default_nb)
default_raw_binary_vectors_new, default_binary_entities_new = utils.gen_binary_entities_new(utils.default_nb)
default_entities_rows = utils.gen_entities_rows(utils.default_nb)
default_raw_binary_vectors_rows, default_binary_entities_rows = utils.gen_binary_entities_rows(utils.default_nb)

127
tests/python/factorys.py Normal file
View File

@ -0,0 +1,127 @@
# STL imports
import random
import string
import time
import datetime
import random
import struct
import sys
import uuid
from functools import wraps
sys.path.append('..')
# Third party imports
import numpy as np
import faker
from faker.providers import BaseProvider
# local application imports
from milvus.client.types import IndexType, MetricType, DataType
# grpc
from milvus.client.grpc_handler import Prepare as gPrepare
from milvus.grpc_gen import milvus_pb2
def gen_vectors(num, dim):
return [[random.random() for _ in range(dim)] for _ in range(num)]
def gen_single_vector(dim):
return [[random.random() for _ in range(dim)]]
def gen_vector(nb, d, seed=np.random.RandomState(1234)):
xb = seed.rand(nb, d).astype("float32")
return xb.tolist()
def gen_unique_str(str=None):
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
return prefix if str is None else str + "_" + prefix
def get_current_day():
return time.strftime('%Y-%m-%d', time.localtime())
def get_last_day(day):
tmp = datetime.datetime.now() - datetime.timedelta(days=day)
return tmp.strftime('%Y-%m-%d')
def get_next_day(day):
tmp = datetime.datetime.now() + datetime.timedelta(days=day)
return tmp.strftime('%Y-%m-%d')
def gen_long_str(num):
    long_str = ''
    for _ in range(num):
        char = random.choice('tomorrow')
        long_str += char
    return long_str
def gen_one_binary(topk):
ids = [random.randrange(10000000, 99999999) for _ in range(topk)]
distances = [random.random() for _ in range(topk)]
return milvus_pb2.TopKQueryResult(struct.pack(str(topk) + 'l', *ids), struct.pack(str(topk) + 'd', *distances))
def gen_nq_binaries(nq, topk):
return [gen_one_binary(topk) for _ in range(nq)]
def fake_query_bin_result(nq, topk):
return gen_nq_binaries(nq, topk)
class FakerProvider(BaseProvider):
def collection_name(self):
return 'collection_names' + str(uuid.uuid4()).replace('-', '_')
def normal_field_name(self):
return 'normal_field_names' + str(uuid.uuid4()).replace('-', '_')
def vector_field_name(self):
return 'vector_field_names' + str(uuid.uuid4()).replace('-', '_')
def name(self):
return 'name' + str(random.randint(1000, 9999))
def dim(self):
return random.randint(0, 999)
fake = faker.Faker()
fake.add_provider(FakerProvider)
def collection_name_factory():
return fake.collection_name()
def collection_schema_factory():
param = {
"fields": [
{"name": fake.normal_field_name(),"type": DataType.INT32},
{"name": fake.vector_field_name(),"type": DataType.FLOAT_VECTOR, "params": {"dim": random.randint(1, 999)}},
],
"auto_id": True,
}
return param
def records_factory(dimension, nq):
return [[random.random() for _ in range(dimension)] for _ in range(nq)]
def time_it(func):
@wraps(func)
def inner(*args, **kwrgs):
pref = time.perf_counter()
result = func(*args, **kwrgs)
delt = time.perf_counter() - pref
print(f"[{func.__name__}][{delt:.4}s]")
return result
return inner

19
tests/python/pytest.ini Normal file
View File

@ -0,0 +1,19 @@
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S
# cli arguments. `-x`: stop the test run when an error occurs;
addopts = -x
testpaths = .
log_cli = true
log_level = 10
timeout = 360
markers =
level: test level
serial
; level = 1

View File

@ -0,0 +1,8 @@
grpcio==1.26.0
grpcio-tools==1.26.0
numpy==1.18.1
pytest==5.3.4
pytest-cov==2.8.1
pytest-timeout==1.3.4
pymilvus-distributed==0.0.3
sklearn==0.0

File diff suppressed because it is too large

View File

@ -0,0 +1,314 @@
import pytest
from .utils import *
from .constants import *
uid = "create_collection"
class TestCreateCollection:
"""
******************************************************************
The following cases are used to test the `create_collection` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_single_filter_fields()
)
def get_filter_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_single_vector_fields()
)
def get_vector_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_segment_row_limits()
)
def get_segment_row_limit(self, request):
yield request.param
def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields
method: create collection with diff fields: metric/field_type/...
expected: no exception raised
'''
filter_field = get_filter_field
# logging.getLogger().info(filter_field)
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
}
# logging.getLogger().info(fields)
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields
method: create collection with diff fields: metric/field_type/...
expected: no exception raised
'''
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
}
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
@pytest.mark.skip("no segment_row_limit")
def test_create_collection_segment_row_limit(self, connect):
'''
target: test create normal collection with different fields
method: create collection with diff segment_row_limit
expected: no exception raised
'''
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
# fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
@pytest.mark.skip("no flush")
def _test_create_collection_auto_flush_disabled(self, connect):
'''
target: test create normal collection, with large auto_flush_interval
method: create collection with correct params
expected: create status returns ok
'''
disable_flush(connect)
collection_name = gen_unique_str(uid)
try:
connect.create_collection(collection_name, default_fields)
finally:
enable_flush(connect)
def test_create_collection_after_insert(self, connect, collection):
'''
target: test insert vector, then create collection again
method: insert vector and create collection
expected: error raised
'''
# pdb.set_trace()
connect.bulk_insert(collection, default_entity)
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
def test_create_collection_after_insert_flush(self, connect, collection):
'''
target: test insert vector, then create collection again
method: insert vector and create collection
expected: error raised
'''
connect.bulk_insert(collection, default_entity)
connect.flush([collection])
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
def test_create_collection_without_connection(self, dis_connect):
'''
target: test create collection, without connection
method: create collection with correct params, with a disconnected instance
expected: error raised
'''
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
dis_connect.create_collection(collection_name, default_fields)
def test_create_collection_existed(self, connect):
'''
target: test create collection when the collection name already exists
method: create collection with the same collection_name
expected: error raised
'''
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
def test_create_after_drop_collection(self, connect, collection):
'''
target: create with the same collection name after collection dropped
method: delete, then create
expected: create success
'''
connect.drop_collection(collection)
time.sleep(2)
connect.create_collection(collection, default_fields)
@pytest.mark.level(2)
def test_create_collection_multithread(self, connect):
'''
target: test create collection with multithread
method: create collections using multiple threads
expected: collections are created
'''
threads_num = 8
threads = []
collection_names = []
def create():
collection_name = gen_unique_str(uid)
collection_names.append(collection_name)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()
for item in collection_names:
assert item in connect.list_collections()
connect.drop_collection(item)
class TestCreateCollectionInvalid(object):
"""
Test creating collections with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_metric_types()
)
def get_metric_type(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_ints()
)
def get_segment_row_limit(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_ints()
)
def get_dim(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_invalid_string(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_field_types()
)
def get_field_type(self, request):
yield request.param
@pytest.mark.level(2)
@pytest.mark.skip("no segment row limit")
def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit):
collection_name = gen_unique_str()
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_limit
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@pytest.mark.level(2)
def test_create_collection_with_invalid_dimension(self, connect, get_dim):
dimension = get_dim
collection_name = gen_unique_str()
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["params"]["dim"] = dimension
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@pytest.mark.level(2)
def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
collection_name = get_invalid_string
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
@pytest.mark.level(2)
def test_create_collection_with_empty_collectionname(self, connect):
collection_name = ''
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
@pytest.mark.level(2)
def test_create_collection_with_none_collectionname(self, connect):
collection_name = None
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
def test_create_collection_None(self, connect):
'''
target: test create collection but the collection name is None
method: create collection, param collection_name is None
expected: create raise error
'''
with pytest.raises(Exception) as e:
connect.create_collection(None, default_fields)
def test_create_collection_no_dimension(self, connect):
'''
target: test create collection with no dimension param in the vector field
method: create collection with the dim param removed from the vector field
expected: error raised
'''
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["fields"][-1]["params"].pop("dim")
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@pytest.mark.skip("no segment row limit")
def test_create_collection_no_segment_row_limit(self, connect):
'''
target: test create collection with no segment_row_limit params
method: create collection with correct params
expected: the server default segment_row_limit is used
'''
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields.pop("segment_row_limit")
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
# logging.getLogger().info(res)
assert res["segment_row_limit"] == default_server_segment_row_limit
def test_create_collection_limit_fields(self, connect):
collection_name = gen_unique_str(uid)
limit_num = 64
fields = copy.deepcopy(default_fields)
for i in range(limit_num):
field_name = gen_unique_str("field_name")
field = {"name": field_name, "type": DataType.INT64}
fields["fields"].append(field)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@pytest.mark.level(2)
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
field_name = get_invalid_string
field = {"name": field_name, "type": DataType.INT64}
fields["fields"].append(field)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
def test_create_collection_invalid_field_type(self, connect, get_field_type):
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
field_type = get_field_type
field = {"name": "test_field", "type": field_type}
fields["fields"].append(field)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)

View File

@ -0,0 +1,98 @@
import pytest
from .utils import *
from .constants import *
uniq_id = "drop_collection"
class TestDropCollection:
"""
******************************************************************
The following cases are used to test `drop_collection` function
******************************************************************
"""
def test_drop_collection(self, connect, collection):
'''
target: test delete collection created with correct params
method: create collection and then delete,
assert the value returned by delete method
expected: status ok, and no collection in collections
'''
connect.drop_collection(collection)
time.sleep(2)
assert not connect.has_collection(collection)
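# The fixed time.sleep(2) above gives the server time to apply the drop. A
# polling helper is usually more robust; a minimal sketch, assuming only the
# connect.has_collection() call already used in these tests:
#
#     def wait_for_drop(connect, name, timeout=10, interval=0.5):
#         deadline = time.time() + timeout
#         while time.time() < deadline:
#             if not connect.has_collection(name):
#                 return True
#             time.sleep(interval)
#         return False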
def test_drop_collection_without_connection(self, collection, dis_connect):
'''
target: test drop collection, without connection
method: drop collection with correct params, with a disconnected instance
expected: drop raises an exception
'''
with pytest.raises(Exception) as e:
dis_connect.drop_collection(collection)
def test_drop_collection_not_existed(self, connect):
'''
target: test drop collection which is not created
method: generate a random collection name which does not exist in db,
then call drop_collection and assert the exception raised
expected: raise an exception
'''
collection_name = gen_unique_str(uniq_id)
with pytest.raises(Exception) as e:
connect.drop_collection(collection_name)
@pytest.mark.level(2)
def test_create_drop_collection_multithread(self, connect):
'''
target: test create and drop collection with multithread
method: create and drop collection using multithread,
expected: collections are created, and dropped
'''
threads_num = 8
threads = []
collection_names = []
def create():
collection_name = gen_unique_str(uniq_id)
collection_names.append(collection_name)
connect.create_collection(collection_name, default_fields)
connect.drop_collection(collection_name)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()
for item in collection_names:
assert not connect.has_collection(item)
class TestDropCollectionInvalid(object):
"""
Test has collection with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
connect.drop_collection(collection_name)
def test_drop_collection_with_empty_collectionname(self, connect):
collection_name = ''
with pytest.raises(Exception) as e:
connect.drop_collection(collection_name)
def test_drop_collection_with_none_collectionname(self, connect):
collection_name = None
with pytest.raises(Exception) as e:
connect.drop_collection(collection_name)

View File

@ -0,0 +1,234 @@
import pytest
from .utils import *
from .constants import *
uid = "collection_info"
class TestInfoBase:
@pytest.fixture(
scope="function",
params=gen_single_filter_fields()
)
def get_filter_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_single_vector_fields()
)
def get_vector_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_segment_row_limits()
)
def get_segment_row_limit(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
logging.getLogger().info(request.param)
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("sq8h not support in CPU mode")
return request.param
"""
******************************************************************
The following cases are used to test `get_collection_info` function, no data in collection
******************************************************************
"""
@pytest.mark.skip("no segment row limit and type")
def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields, check info returned
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
expected: no exception raised, and value returned correct
'''
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
}
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == default_segment_row_limit
assert len(res["fields"]) == 2
for field in res["fields"]:
if field["type"] == filter_field:
assert field["name"] == filter_field["name"]
elif field["type"] == vector_field:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
@pytest.mark.skip("no segment row limit and type")
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
'''
target: test create normal collection with different fields
method: create collection with diff segment_row_limit
expected: no exception raised
'''
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
# assert segment row count
res = connect.get_collection_info(collection_name)
assert res['segment_row_limit'] == get_segment_row_limit
@pytest.mark.skip("no create Index")
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
res = connect.get_collection_info(collection)
for field in res["fields"]:
if field["name"] == default_float_vec_field_name:
index = field["indexes"][0]
assert index["index_type"] == get_simple_index["index_type"]
assert index["metric_type"] == get_simple_index["metric_type"]
@pytest.mark.level(2)
def test_get_collection_info_without_connection(self, connect, collection, dis_connect):
'''
target: test get collection info, without connection
method: calling get collection info with correct params, with a disconnected instance
expected: get collection info raise exception
'''
with pytest.raises(Exception) as e:
dis_connect.get_collection_info(collection)
def test_get_collection_info_not_existed(self, connect):
'''
target: test get collection info of a collection which is not created
method: generate a random collection name which does not exist in db,
then call get_collection_info
expected: raise an exception
'''
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
res = connect.get_collection_info(collection_name)
@pytest.mark.level(2)
def test_get_collection_info_multithread(self, connect):
'''
target: test get collection info with multithread
method: get collection info using multiple threads
expected: info is returned in every thread without error
'''
threads_num = 4
threads = []
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def get_info():
res = connect.get_collection_info(collection_name)
# assert
for i in range(threads_num):
t = threading.Thread(target=get_info, args=())
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()
"""
******************************************************************
The following cases are used to test `get_collection_info` function, and insert data in collection
******************************************************************
"""
@pytest.mark.skip("no segment row limit and type")
def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field):
'''
target: test create normal collection with different fields, check info returned
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
expected: no exception raised, and value returned correct
'''
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": default_segment_row_limit
}
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
res_ids = connect.bulk_insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == default_segment_row_limit
assert len(res["fields"]) == 2
for field in res["fields"]:
if field["type"] == filter_field:
assert field["name"] == filter_field["name"]
elif field["type"] == vector_field:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
@pytest.mark.skip("not segment row limit")
def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit):
'''
target: test create normal collection with different fields
method: create collection with diff segment_row_limit
expected: no exception raised
'''
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
res_ids = connect.bulk_insert(collection_name, entities)
connect.flush([collection_name])
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == get_segment_row_limit
class TestInfoInvalid(object):
"""
Test get collection info with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_get_collection_info_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
connect.get_collection_info(collection_name)
@pytest.mark.level(2)
def test_get_collection_info_with_empty_collectionname(self, connect):
collection_name = ''
with pytest.raises(Exception) as e:
connect.get_collection_info(collection_name)
@pytest.mark.level(2)
def test_get_collection_info_with_none_collectionname(self, connect):
collection_name = None
with pytest.raises(Exception) as e:
connect.get_collection_info(collection_name)
def test_get_collection_info_None(self, connect):
'''
target: test get collection info but the collection name is None
method: call get_collection_info with param collection_name None
expected: raise an exception
'''
with pytest.raises(Exception) as e:
connect.get_collection_info(None)

View File

@ -0,0 +1,93 @@
import pytest
from .utils import *
from .constants import *
uid = "has_collection"
class TestHasCollection:
"""
******************************************************************
The following cases are used to test `has_collection` function
******************************************************************
"""
def test_has_collection(self, connect, collection):
'''
target: test if the created collection existed
method: create collection, assert the value returned by has_collection method
expected: True
'''
assert connect.has_collection(collection)
@pytest.mark.level(2)
def test_has_collection_without_connection(self, collection, dis_connect):
'''
target: test has collection, without connection
method: calling has collection with correct params, with a disconnected instance
expected: has collection raise exception
'''
with pytest.raises(Exception) as e:
assert dis_connect.has_collection(collection)
def test_has_collection_not_existed(self, connect):
'''
target: test has_collection on a collection which is not created
method: generate a random collection name which does not exist in db,
then check the value returned by has_collection
expected: False
'''
collection_name = gen_unique_str("test_collection")
assert not connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_has_collection_multithread(self, connect):
'''
target: test has collection with multithread
method: check has_collection using multiple threads
expected: every thread sees the created collection
'''
threads_num = 4
threads = []
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def has():
assert connect.has_collection(collection_name)
# assert not assert_collection(connect, collection_name)
for i in range(threads_num):
t = MilvusTestThread(target=has, args=())
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()
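# MilvusTestThread is assumed to be provided by .utils. Unlike the plain
# threading.Thread used in some other multithread tests here, it should
# re-raise worker exceptions on join() so a failed assert actually fails the
# test. A minimal sketch of such a wrapper (illustrative, not the actual
# utils implementation):
#
#     class MilvusTestThread(threading.Thread):
#         def run(self):
#             self.exc = None
#             try:
#                 super().run()
#             except Exception as e:
#                 self.exc = e
#         def join(self, timeout=None):
#             super().join(timeout)
#             if self.exc:
#                 raise self.exc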
class TestHasCollectionInvalid(object):
"""
Test has collection with invalid params
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_has_collection_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_has_collection_with_empty_collectionname(self, connect):
collection_name = ''
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_has_collection_with_none_collectionname(self, connect):
collection_name = None
with pytest.raises(Exception) as e:
connect.has_collection(collection_name)

462
tests/python/test_insert.py Normal file
View File

@ -0,0 +1,462 @@
import pytest
from .utils import *
from .constants import *
ADD_TIMEOUT = 600
uid = "test_insert"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2",
"params": {"nprobe": 10}}}}
]
}
}
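# default_single_query above follows the boolean query DSL used by these tests:
# "topk" is the number of nearest neighbours to return, "metric_type" the
# distance metric (L2 here), and "params"/"nprobe" the index-dependent search
# parameter; the query vectors are keyed by the vector field name.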
class TestInsertBase:
"""
******************************************************************
The following cases are used to test `insert` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_simple_index()
)
def get_simple_index(self, request, connect):
if str(connect._cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
@pytest.fixture(
scope="function",
params=gen_single_filter_fields()
)
def get_filter_field(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_single_vector_fields()
)
def get_vector_field(self, request):
yield request.param
def test_add_vector_with_empty_vector(self, connect, collection):
'''
target: test add vectors with empty vectors list
method: set empty vectors list as add method params
expected: raises an Exception
'''
vector = []
with pytest.raises(Exception) as e:
status, ids = connect.insert(collection, vector)
def test_add_vector_with_None(self, connect, collection):
'''
target: test add vectors with None
method: set None as add method params
expected: raises an Exception
'''
vector = None
with pytest.raises(Exception) as e:
status, ids = connect.insert(collection, vector)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_collection_not_existed(self, connect):
'''
target: test insert, with collection not existed
method: insert entity into a random named collection
expected: error raised
'''
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
connect.insert(collection_name, default_entities_rows)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_drop_collection(self, connect, collection):
'''
target: test delete collection after insert vector
method: insert vector and delete collection
expected: no error raised
'''
ids = connect.insert(collection, default_entity_row)
assert len(ids) == 1
connect.drop_collection(collection)
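# default_entity_row / default_entities_rows come from .constants. The row-based
# insert API used in this file expects a list of dicts, one per entity, mapping
# field name to value; an illustrative single-row payload (assumed shape, not
# the authoritative constants definition) would look like:
#     [{"int64": 0, "float": 0.0,
#       default_float_vec_field_name: gen_vectors(1, default_dim)[0]}]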
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_sleep_drop_collection(self, connect, collection):
'''
target: test delete collection after insert and flush
method: insert vector, flush, and delete collection
expected: no error raised
'''
ids = connect.insert(collection, default_entity_row)
assert len(ids) == 1
connect.flush([collection])
connect.drop_collection(collection)
@pytest.mark.skip("create_index")
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_create_index(self, connect, collection, get_simple_index):
'''
target: test build index after inserting vectors
method: insert vectors and build index
expected: no error raised
'''
ids = connect.insert(collection, default_entities_rows)
assert len(ids) == default_nb
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["name"] == field_name:
assert field["indexes"][0] == get_simple_index
@pytest.mark.skip("create_index")
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_after_create_index(self, connect, collection, get_simple_index):
'''
target: test insert vectors after building index
method: build index and insert vectors
expected: no error raised
'''
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, default_entities_rows)
assert len(ids) == default_nb
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["name"] == field_name:
assert field["indexes"][0] == get_simple_index
@pytest.mark.skip(" todo fix search")
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_search(self, connect, collection):
'''
target: test search after inserting vectors
method: insert vectors, flush, and search the collection
expected: no error raised
'''
ids = connect.insert(collection, default_entities_rows)
connect.flush([collection])
res = connect.search(collection, default_single_query)
logging.getLogger().debug(res)
assert res
@pytest.mark.skip("segment row count")
def test_insert_segment_row_count(self, connect, collection):
nb = default_segment_row_limit + 1
res_ids = connect.insert(collection, gen_entities_rows(nb))
connect.flush([collection])
assert len(res_ids) == nb
stats = connect.get_collection_stats(collection)
assert len(stats['partitions'][0]['segments']) == 2
for segment in stats['partitions'][0]['segments']:
assert segment['row_count'] in [default_segment_row_limit, 1]
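# With nb = default_segment_row_limit + 1 rows flushed, the data is expected to
# be sealed into two segments: one full segment of default_segment_row_limit
# rows plus a second segment holding the single remaining row, which is exactly
# what the two asserts above check.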
@pytest.fixture(
scope="function",
params=[
1,
2000
],
)
def insert_count(self, request):
yield request.param
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_ids_not_match(self, connect, id_collection, insert_count):
'''
target: test insert entities into an id_collection without customized ids
method: insert rows that do not carry the _id field into an id_collection
expected: raise an exception
'''
nb = insert_count
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, gen_entities_rows(nb))
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_twice_ids_no_ids(self, connect, collection):
'''
target: check insert with customized ids into a collection created with auto ids
method: insert rows carrying the _id field into an auto-id collection
expected: error raised
'''
with pytest.raises(Exception) as e:
res_ids = connect.insert(collection, gen_entities_rows(default_nb, _id=False))
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag(self, connect, collection):
'''
target: test insert entities in collection created before
method: create collection and insert entities in it, with the partition_tag param
expected: the collection row count equals to nb
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(collection, default_tag)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag_with_ids(self, connect, id_collection):
'''
target: test insert entities in collection created before, insert with ids
method: create collection and insert entities in it, with the partition_tag param
expected: the collection row count equals to nb
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
res_ids = connect.insert(id_collection, gen_entities_rows(default_nb, _id=False), partition_tag=default_tag)
assert res_ids == ids
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag_not_existed(self, connect, collection):
'''
target: test insert entities in collection created before
method: create collection and insert entities in it, with the not existed partition_tag param
expected: error raised
'''
tag = gen_unique_str()
with pytest.raises(Exception) as e:
ids = connect.insert(collection, default_entities_rows, partition_tag=tag)
@pytest.mark.skip("todo support count entites")
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag_existed(self, connect, collection):
'''
target: test insert entities in collection created before
method: create collection and insert entities into it repeatedly, with the partition_tag param
expected: the collection row count equals to 2 * nb
'''
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
connect.flush([collection])
res_count = connect.count_entities(collection)
assert res_count == 2 * default_nb
@pytest.mark.level(2)
def test_insert_entities_collection_not_existed(self, connect):
'''
target: test insert entities in collection, which not existed before
method: insert entities collection not existed, check the status
expected: error raised
'''
with pytest.raises(Exception) as e:
ids = connect.insert(gen_unique_str("not_exist_collection"), default_entities_rows)
@pytest.mark.skip("todo support row data check")
def test_insert_dim_not_matched(self, connect, collection):
'''
target: test insert entities, the vector dimension is not equal to the collection dimension
method: the entities dimension is half of the collection dimension, check the status
expected: error raised
'''
vectors = gen_vectors(default_nb, int(default_dim) // 2)
insert_entities = copy.deepcopy(default_entities_rows)
insert_entities[-1][default_float_vec_field_name] = vectors
with pytest.raises(Exception) as e:
ids = connect.insert(collection, insert_entities)
class TestInsertBinary:
@pytest.fixture(
scope="function",
params=gen_binary_index()
)
def get_binary_index(self, request):
request.param["metric_type"] = "JACCARD"
return request.param
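# Binary vector fields are searched with binary metrics (e.g. JACCARD, HAMMING),
# so the fixture above overrides metric_type to JACCARD for every index
# configuration produced by gen_binary_index().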
@pytest.mark.skip("count entities")
def test_insert_binary_entities(self, connect, binary_collection):
'''
target: test insert entities in binary collection
method: create collection and insert binary entities in it
expected: the collection row count equals to nb
'''
ids = connect.insert(binary_collection, default_binary_entities_rows)
assert len(ids) == default_nb
connect.flush()
assert connect.count_entities(binary_collection) == default_nb
def test_insert_binary_tag(self, connect, binary_collection):
'''
target: test insert entities and create partition tag
method: create collection and insert binary entities in it, with the partition_tag param
expected: the collection row count equals to nb
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities_rows, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(binary_collection, default_tag)
# TODO
@pytest.mark.skip("count entities")
@pytest.mark.level(2)
def test_insert_binary_multi_times(self, connect, binary_collection):
'''
target: test insert entities multi times and final flush
method: create collection and insert binary entity multi and final flush
expected: the collection row count equals to nb
'''
for i in range(default_nb):
ids = connect.insert(binary_collection, default_binary_entity_row)
assert len(ids) == 1
connect.flush([binary_collection])
assert connect.count_entities(binary_collection) == default_nb
@pytest.mark.skip("create index")
def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test insert binary entities after build index
method: build index and insert entities
expected: no error raised
'''
connect.create_index(binary_collection, binary_field_name, get_binary_index)
ids = connect.insert(binary_collection, default_binary_entities_rows)
assert len(ids) == default_nb
connect.flush([binary_collection])
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
@pytest.mark.skip("create index")
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test build index after inserting binary entities
method: insert binary entities and build index
expected: no error raised
'''
ids = connect.insert(binary_collection, default_binary_entities_rows)
assert len(ids) == default_nb
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_binary_index)
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
class TestInsertInvalid(object):
"""
Test insert with invalid params (collection name, tag name, field values)
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_tag_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_field_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_field_type(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_field_int_value(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_ints()
)
def get_entity_id(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_vectors()
)
def get_field_vectors_value(self, request):
yield request.param
@pytest.mark.skip("todo support row data check")
def test_insert_field_name_not_match(self, connect, collection):
'''
target: test insert, with field name not matched
method: create collection and insert entities in it
expected: raise an exception
'''
tmp_entity = copy.deepcopy(default_entity_row)
tmp_entity[0]["string"] = "string"
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception):
connect.insert(collection_name, default_entity_row)
def test_insert_with_invalid_tag_name(self, connect, collection, get_tag_name):
tag_name = get_tag_name
connect.create_partition(collection, default_tag)
if tag_name is not None:
with pytest.raises(Exception):
connect.insert(collection, default_entity_row, partition_tag=tag_name)
else:
connect.insert(collection, default_entity_row, partition_tag=tag_name)
@pytest.mark.skip("todo support row data check")
def test_insert_with_less_field(self, connect, collection):
tmp_entity = copy.deepcopy(default_entity_row)
tmp_entity[0].pop(default_float_vec_field_name)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_less_field_id(self, connect, id_collection):
tmp_entity = copy.deepcopy(gen_entities_rows(default_nb, _id=False))
tmp_entity[0].pop("_id")
with pytest.raises(Exception):
connect.insert(id_collection, tmp_entity)
def test_insert_with_more_field(self, connect, collection):
tmp_entity = copy.deepcopy(default_entity_row)
tmp_entity[0]["new_field"] = 1
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_more_field_id(self, connect, collection):
tmp_entity = copy.deepcopy(default_entity_row)
tmp_entity[0]["_id"] = 1
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@pytest.mark.skip("todo support row data check")
def test_insert_with_invalid_field_vector_value(self, connect, collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(default_entity_row)
tmp_entity[0][default_float_vec_field_name][1] = get_field_vectors_value
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)

View File

@ -0,0 +1,86 @@
import pytest
from .utils import *
from .constants import *
uid = "list_collections"
class TestListCollections:
"""
******************************************************************
The following cases are used to test `list_collections` function
******************************************************************
"""
def test_list_collections(self, connect, collection):
'''
target: test list collections
method: create collection, assert the value returned by list_collections method
expected: True
'''
assert collection in connect.list_collections()
def test_list_collections_multi_collections(self, connect):
'''
target: test list collections
method: create collection, assert the value returned by list_collections method
expected: True
'''
collection_num = 50
for i in range(collection_num):
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
assert collection_name in connect.list_collections()
@pytest.mark.level(2)
def test_list_collections_without_connection(self, dis_connect):
'''
target: test list collections, without connection
method: calling list collections with correct params, with a disconnected instance
expected: list collections raise exception
'''
with pytest.raises(Exception) as e:
dis_connect.list_collections()
def test_list_collections_not_existed(self, connect):
'''
target: test list collections when a collection is not created
method: generate a random collection name which does not exist in db,
then check the value returned by list_collections
expected: the name is not in the returned list
'''
collection_name = gen_unique_str(uid)
assert collection_name not in connect.list_collections()
@pytest.mark.level(2)
def test_list_collections_no_collection(self, connect):
'''
target: test that every collection returned by list_collections actually exists
method: list collections, then check has_collection for each returned name
expected: has_collection returns True for every listed collection
'''
result = connect.list_collections()
if result:
for collection_name in result:
assert connect.has_collection(collection_name)
@pytest.mark.level(2)
def test_list_collections_multithread(self, connect):
'''
target: test list collections with multithread
method: check list_collections using multiple threads
expected: the created collection is listed in every thread
'''
threads_num = 4
threads = []
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def _list():
assert collection_name in connect.list_collections()
for i in range(threads_num):
t = threading.Thread(target=_list, args=())
threads.append(t)
t.start()
time.sleep(0.2)
for t in threads:
t.join()

View File

@ -0,0 +1,396 @@
import pytest
from .utils import *
from .constants import *
TIMEOUT = 120
class TestCreateBase:
"""
******************************************************************
The following cases are used to test `create_partition` function
******************************************************************
"""
def test_create_partition(self, connect, collection):
'''
target: test create partition, check status returned
method: call function: create_partition
expected: status ok
'''
connect.create_partition(collection, default_tag)
@pytest.mark.level(2)
@pytest.mark.timeout(600)
@pytest.mark.skip
def test_create_partition_limit(self, connect, collection, args):
'''
target: test create partitions, check status returned
method: call function: create_partition for 4097 times
expected: exception raised
'''
threads_num = 8
threads = []
if args["handler"] == "HTTP":
pytest.skip("skip in http mode")
def create(connect, threads_num):
for i in range(max_partition_num // threads_num):
tag_tmp = gen_unique_str()
connect.create_partition(collection, tag_tmp)
for i in range(threads_num):
m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
t = threading.Thread(target=create, args=(m, threads_num, ))
threads.append(t)
t.start()
for t in threads:
t.join()
tag_tmp = gen_unique_str()
with pytest.raises(Exception) as e:
connect.create_partition(collection, tag_tmp)
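# The threads above create max_partition_num // threads_num partitions each,
# i.e. max_partition_num in total (the cap is assumed to be defined in
# .constants), so the single extra create_partition afterwards is expected to
# exceed the limit and raise.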
def test_create_partition_repeat(self, connect, collection):
'''
target: test create partition twice with the same tag
method: call function: create_partition twice with the same tag
expected: the second call raises an exception
'''
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.create_partition(collection, default_tag)
def test_create_partition_collection_not_existed(self, connect):
'''
target: test create partition, its owner collection name not existed in db, check status returned
method: call function: create_partition
expected: status not ok
'''
collection_name = gen_unique_str()
with pytest.raises(Exception) as e:
connect.create_partition(collection_name, default_tag)
def test_create_partition_tag_name_None(self, connect, collection):
'''
target: test create partition, tag name set None, check status returned
method: call function: create_partition
expected: raise an exception
'''
tag_name = None
with pytest.raises(Exception) as e:
connect.create_partition(collection, tag_name)
def test_create_different_partition_tags(self, connect, collection):
'''
target: test create partition twice with different names
method: call function: create_partition, and again
expected: status ok
'''
connect.create_partition(collection, default_tag)
tag_name = gen_unique_str()
connect.create_partition(collection, tag_name)
tag_list = connect.list_partitions(collection)
assert default_tag in tag_list
assert tag_name in tag_list
assert "_default" in tag_list
@pytest.mark.skip("not support custom id")
def test_create_partition_insert_default(self, connect, id_collection):
'''
target: test create partition, and insert vectors, check status returned
method: call function: create_partition
expected: status ok
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids)
assert len(insert_ids) == len(ids)
@pytest.mark.skip("not support custom id")
def test_create_partition_insert_with_tag(self, connect, id_collection):
'''
target: test create partition, and insert vectors, check status returned
method: call function: create_partition
expected: status ok
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
assert len(insert_ids) == len(ids)
def test_create_partition_insert_with_tag_not_existed(self, connect, collection):
'''
target: test create partition, and insert vectors, check status returned
method: call function: create_partition
expected: status not ok
'''
tag_new = "tag_new"
connect.create_partition(collection, default_tag)
ids = [i for i in range(default_nb)]
with pytest.raises(Exception) as e:
insert_ids = connect.bulk_insert(collection, default_entities, ids, partition_tag=tag_new)
@pytest.mark.skip("not support custom id")
def test_create_partition_insert_same_tags(self, connect, id_collection):
'''
target: test create partition, and insert vectors, check status returned
method: call function: create_partition
expected: status ok
'''
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = [(i+default_nb) for i in range(default_nb)]
new_insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
connect.flush([id_collection])
res = connect.count_entities(id_collection)
assert res == default_nb * 2
@pytest.mark.level(2)
@pytest.mark.skip("not support count entities")
def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
'''
target: test create two partitions, and insert vectors with the same tag to each collection, check status returned
method: call function: create_partition
expected: status ok, collection length is correct
'''
connect.create_partition(collection, default_tag)
collection_new = gen_unique_str()
connect.create_collection(collection_new, default_fields)
connect.create_partition(collection_new, default_tag)
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
ids = connect.bulk_insert(collection_new, default_entities, partition_tag=default_tag)
connect.flush([collection, collection_new])
res = connect.count_entities(collection)
assert res == default_nb
res = connect.count_entities(collection_new)
assert res == default_nb
class TestShowBase:
"""
******************************************************************
The following cases are used to test `list_partitions` function
******************************************************************
"""
def test_list_partitions(self, connect, collection):
'''
target: test show partitions, check status and partitions returned
method: create partition first, then call function: list_partitions
expected: status ok, partition correct
'''
connect.create_partition(collection, default_tag)
res = connect.list_partitions(collection)
assert default_tag in res
def test_list_partitions_no_partition(self, connect, collection):
'''
target: test show partitions with collection name, check status and partitions returned
method: call function: list_partitions
expected: status ok, partitions correct
'''
res = connect.list_partitions(collection)
assert len(res) == 1
def test_show_multi_partitions(self, connect, collection):
'''
target: test show partitions, check status and partitions returned
method: create partitions first, then call function: list_partitions
expected: status ok, partitions correct
'''
tag_new = gen_unique_str()
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
res = connect.list_partitions(collection)
assert default_tag in res
assert tag_new in res
class TestHasBase:
"""
******************************************************************
The following cases are used to test `has_partition` function
******************************************************************
"""
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_tag_name(self, request):
yield request.param
def test_has_partition(self, connect, collection):
'''
target: test has_partition, check status and result
method: create partition first, then call function: has_partition
expected: status ok, result true
'''
connect.create_partition(collection, default_tag)
res = connect.has_partition(collection, default_tag)
logging.getLogger().info(res)
assert res
def test_has_partition_multi_partitions(self, connect, collection):
'''
target: test has_partition, check status and result
method: create partition first, then call function: has_partition
expected: status ok, result true
'''
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
connect.create_partition(collection, tag_name)
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
res = connect.has_partition(collection, tag_name)
assert res
def test_has_partition_tag_not_existed(self, connect, collection):
'''
target: test has_partition, check status and result
method: then call function: has_partition, with tag not existed
expected: status ok, result empty
'''
res = connect.has_partition(collection, default_tag)
logging.getLogger().info(res)
assert not res
def test_has_partition_collection_not_existed(self, connect, collection):
'''
target: test has_partition, check status and result
method: then call function: has_partition, with collection not existed
expected: status not ok
'''
with pytest.raises(Exception) as e:
res = connect.has_partition("not_existed_collection", default_tag)
@pytest.mark.level(2)
def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
'''
target: test has partition, with invalid tag name, check status returned
method: call function: has_partition
expected: raise an exception
'''
tag_name = get_tag_name
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
res = connect.has_partition(collection, tag_name)
class TestDropBase:
"""
******************************************************************
The following cases are used to test `drop_partition` function
******************************************************************
"""
def test_drop_partition(self, connect, collection):
'''
target: test drop partition, check status and partition if existed
method: create partitions first, then call function: drop_partition
expected: status ok, no partitions in db
'''
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
tag_list = connect.list_partitions(collection)
assert default_tag not in tag_list
def test_drop_partition_tag_not_existed(self, connect, collection):
'''
target: test drop partition, but tag not existed
method: create partitions first, then call function: drop_partition
expected: status not ok
'''
connect.create_partition(collection, default_tag)
new_tag = "new_tag"
with pytest.raises(Exception) as e:
connect.drop_partition(collection, new_tag)
def test_drop_partition_tag_not_existed_A(self, connect, collection):
'''
target: test drop partition, but collection not existed
method: create partitions first, then call function: drop_partition
expected: status not ok
'''
connect.create_partition(collection, default_tag)
new_collection = gen_unique_str()
with pytest.raises(Exception) as e:
connect.drop_partition(new_collection, default_tag)
@pytest.mark.level(2)
def test_drop_partition_repeatedly(self, connect, collection):
'''
target: test drop partition twice, check status and partition if existed
method: create partitions first, then call function: drop_partition
expected: status not ok, no partitions in db
'''
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
time.sleep(2)
with pytest.raises(Exception) as e:
connect.drop_partition(collection, default_tag)
tag_list = connect.list_partitions(collection)
assert default_tag not in tag_list
def test_drop_partition_create(self, connect, collection):
'''
target: test drop partition, and create again, check status
method: create partitions first, then call function: drop_partition, create_partition
expected: status ok, partition is recreated and present in db
'''
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
time.sleep(2)
connect.create_partition(collection, default_tag)
tag_list = connect.list_partitions(collection)
assert default_tag in tag_list
class TestNameInvalid(object):
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_tag_name(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_invalid_strs()
)
def get_collection_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name):
'''
target: test drop partition, with invalid collection name, check status returned
method: call function: drop_partition
expected: status not ok
'''
collection_name = get_collection_name
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.drop_partition(collection_name, default_tag)
@pytest.mark.level(2)
def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
'''
target: test drop partition, with invalid tag name, check status returned
method: call function: drop_partition
expected: status not ok
'''
tag_name = get_tag_name
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.drop_partition(collection, tag_name)
@pytest.mark.level(2)
def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name):
'''
target: test show partitions, with invalid collection name, check status returned
method: call function: list_partitions
expected: status not ok
'''
collection_name = get_collection_name
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
res = connect.list_partitions(collection_name)

1824
tests/python/test_search.py Normal file

File diff suppressed because it is too large

1005
tests/python/utils.py Normal file

File diff suppressed because it is too large