Add MinIO kv implementation
Signed-off-by: godchen <qingxiang.chen@zilliz.com>
commit 7ab5b5d80d
parent cec903da19
3
.env
@ -1,8 +1,9 @@
|
||||
REPO=milvusdb/milvus-distributed-dev
|
||||
ARCH=amd64
|
||||
UBUNTU=18.04
|
||||
DATE_VERSION=20201120-092740
|
||||
DATE_VERSION=20201202-085131
|
||||
LATEST_DATE_VERSION=latest
|
||||
MINIO_ADDRESS=minio:9000
|
||||
PULSAR_ADDRESS=pulsar://pulsar:6650
|
||||
ETCD_ADDRESS=etcd:2379
|
||||
MASTER_ADDRESS=localhost:53100
|
||||
|
2
.github/workflows/main.yaml
vendored
@ -51,7 +51,7 @@ jobs:
|
||||
- name: Start Service
|
||||
shell: bash
|
||||
run: |
|
||||
docker-compose up -d pulsar etcd
|
||||
docker-compose up -d pulsar etcd minio
|
||||
- name: Build and UnitTest
|
||||
env:
|
||||
CHECK_BUILDER: "1"
|
||||
|
4
.github/workflows/publish-builder.yaml
vendored
@ -6,12 +6,12 @@ on:
|
||||
push:
|
||||
# file paths to consider in the event. Optional; defaults to all.
|
||||
paths:
|
||||
- 'build/docker/**'
|
||||
- 'build/docker/env/**'
|
||||
- '.github/workflows/publish-builder.yaml'
|
||||
pull_request:
|
||||
# file paths to consider in the event. Optional; defaults to all.
|
||||
paths:
|
||||
- 'build/docker/**'
|
||||
- 'build/docker/env/**'
|
||||
- '.github/workflows/publish-builder.yaml'
|
||||
|
||||
jobs:
|
||||
|
@ -12,12 +12,15 @@ dir ('build/docker/deploy') {
|
||||
try {
|
||||
withCredentials([usernamePassword(credentialsId: "${env.DOCKER_CREDENTIALS_ID}", usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) {
|
||||
sh 'docker login -u ${DOCKER_USERNAME} -p ${DOCKER_PASSWORD} ${DOKCER_REGISTRY_URL}'
|
||||
|
||||
sh 'docker pull ${SOURCE_REPO}/master:${SOURCE_TAG} || true'
|
||||
sh 'docker-compose build --force-rm master'
|
||||
sh 'docker-compose push master'
|
||||
|
||||
sh 'docker pull ${SOURCE_REPO}/proxy:${SOURCE_TAG} || true'
|
||||
sh 'docker-compose build --force-rm proxy'
|
||||
sh 'docker-compose push proxy'
|
||||
|
||||
sh 'docker pull ${SOURCE_REPO}/querynode:${SOURCE_TAG} || true'
|
||||
sh 'docker-compose build --force-rm querynode'
|
||||
sh 'docker-compose push querynode'
|
||||
|
37
.jenkins/modules/Regression/PythonRegression.groovy
Normal file
@ -0,0 +1,37 @@
|
||||
try {
|
||||
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d etcd'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
|
||||
dir ('build/docker/deploy') {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
|
||||
}
|
||||
|
||||
dir ('build/docker/test') {
|
||||
sh 'docker pull ${SOURCE_REPO}/pytest:${SOURCE_TAG} || true'
|
||||
sh 'docker-compose build --force-rm regression'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run --rm regression'
|
||||
try {
|
||||
withCredentials([usernamePassword(credentialsId: "${env.DOCKER_CREDENTIALS_ID}", usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) {
|
||||
sh 'docker login -u ${DOCKER_USERNAME} -p ${DOCKER_PASSWORD} ${DOKCER_REGISTRY_URL}'
|
||||
sh 'docker-compose push regression'
|
||||
}
|
||||
} catch (exc) {
|
||||
throw exc
|
||||
} finally {
|
||||
sh 'docker logout ${DOKCER_REGISTRY_URL}'
|
||||
}
|
||||
}
|
||||
} catch(exc) {
|
||||
throw exc
|
||||
} finally {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v pulsar'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v etcd'
|
||||
dir ('build/docker/deploy') {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true'
|
||||
}
|
||||
dir ('build/docker/test') {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run --rm regression /bin/bash -c "rm -rf __pycache__ && rm -rf .pytest_cache"'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true'
|
||||
}
|
||||
}
|
10
Makefile
@ -35,21 +35,21 @@ fmt:
|
||||
@echo "Running $@ check"
|
||||
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh cmd/
|
||||
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh internal/
|
||||
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh test/
|
||||
@GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/go/
|
||||
|
||||
#TODO: Check code specifications by golangci-lint
|
||||
lint:
|
||||
@echo "Running $@ check"
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=3m --config ./.golangci.yml ./test/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
|
||||
|
||||
ruleguard:
|
||||
@echo "Running $@ check"
|
||||
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./internal/...
|
||||
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./cmd/...
|
||||
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./test/...
|
||||
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./tests/go/...
|
||||
|
||||
verifiers: cppcheck fmt lint ruleguard
|
||||
|
||||
|
29
build/ci/jenkins/Jenkinsfile
vendored
@ -18,6 +18,11 @@ pipeline {
|
||||
PACKAGE_ARTFACTORY_URL = "${JFROG_ARTFACTORY_URL}/${PROJECT_NAME}/package/${PACKAGE_NAME}"
|
||||
DOCKER_CREDENTIALS_ID = "ba070c98-c8cc-4f7c-b657-897715f359fc"
|
||||
DOKCER_REGISTRY_URL = "registry.zilliz.com"
|
||||
SOURCE_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
|
||||
TARGET_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
|
||||
SOURCE_TAG = "${CHANGE_TARGET ? CHANGE_TARGET : SEMVER}-${LOWER_BUILD_TYPE}"
|
||||
TARGET_TAG = "${SEMVER}-${LOWER_BUILD_TYPE}"
|
||||
DOCKER_BUILDKIT = 1
|
||||
}
|
||||
stages {
|
||||
stage ('Build and UnitTest') {
|
||||
@ -51,18 +56,28 @@ pipeline {
|
||||
yamlFile "build/ci/jenkins/pod/docker-pod.yaml"
|
||||
}
|
||||
}
|
||||
environment{
|
||||
SOURCE_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
|
||||
TARGET_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed"
|
||||
SOURCE_TAG = "${CHANGE_TARGET ? CHANGE_TARGET : SEMVER}-${LOWER_BUILD_TYPE}"
|
||||
TARGET_TAG = "${SEMVER}-${LOWER_BUILD_TYPE}"
|
||||
DOCKER_BUILDKIT = 1
|
||||
}
|
||||
steps {
|
||||
container('publish-images') {
|
||||
MPLModule('Publish')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage ('Dev Test') {
|
||||
agent {
|
||||
label "performance"
|
||||
}
|
||||
environment {
|
||||
DOCKER_COMPOSE_PROJECT_NAME = "${PROJECT_NAME}-${SEMVER}-${env.BUILD_NUMBER}".replaceAll("\\.", "-").replaceAll("_", "-")
|
||||
}
|
||||
steps {
|
||||
MPLModule('Python Regression')
|
||||
}
|
||||
post {
|
||||
cleanup {
|
||||
deleteDir() /* clean up our workspace */
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -13,8 +13,6 @@ services:
|
||||
ETCD_ADDRESS: ${ETCD_ADDRESS}
|
||||
networks:
|
||||
- milvus
|
||||
ports:
|
||||
- "53100:53100"
|
||||
|
||||
proxy:
|
||||
image: ${TARGET_REPO}/proxy:${TARGET_TAG}
|
||||
@ -26,8 +24,6 @@ services:
|
||||
environment:
|
||||
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
|
||||
MASTER_ADDRESS: ${MASTER_ADDRESS}
|
||||
ports:
|
||||
- "19530:19530"
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
|
18
build/docker/test/Dockerfile
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
FROM python:3.6.8-jessie
|
||||
|
||||
COPY ./tests/python/requirements.txt /requirements.txt
|
||||
|
||||
RUN python3 -m pip install -r /requirements.txt
|
||||
|
||||
CMD ["tail", "-f", "/dev/null"]
|
20
build/docker/test/docker-compose.yml
Normal file
@ -0,0 +1,20 @@
|
||||
version: '3.5'
|
||||
|
||||
services:
|
||||
regression:
|
||||
image: ${TARGET_REPO}/pytest:${TARGET_TAG}
|
||||
build:
|
||||
context: ../../../
|
||||
dockerfile: build/docker/test/Dockerfile
|
||||
cache_from:
|
||||
- ${SOURCE_REPO}/pytest:${SOURCE_TAG}
|
||||
volumes:
|
||||
- ../../..:/milvus-distributed:delegated
|
||||
working_dir: "/milvus-distributed/tests/python"
|
||||
command: >
|
||||
/bin/bash -c "pytest --ip proxy"
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
networks:
|
||||
milvus:
|
@ -42,6 +42,13 @@ etcd:
|
||||
rootpath: by-dev
|
||||
segthreshold: 10000
|
||||
|
||||
minio:
|
||||
address: localhost
|
||||
port: 9000
|
||||
accessKeyID: minioadmin
|
||||
secretAccessKey: minioadmin
|
||||
useSSL: false
|
||||
|
||||
timesync:
|
||||
interval: 400
|
||||
|
||||
|
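The `minio` block added to the config above carries the same endpoint and credential values that the docker-compose services and `.env` (`MINIO_ADDRESS=minio:9000`) expose. The sketch below shows how these settings might be turned into a client for the new MinIO kv; it is illustrative only, and the `paramtable` keys (`_MinioAddress`, `minio.accessKeyID`, `minio.secretAccessKey`, `minio.useSSL`) are simply the ones the test file later in this commit loads.

```go
// Sketch: wire the minio settings from the config above into a MinIO-backed kv.
// The lookups mirror internal/kv/minio/minio_kv_test.go; the bucket name is made up.
package main

import (
	"context"
	"log"
	"strconv"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
	"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
)

func main() {
	var params paramtable.BaseTable
	params.Init()

	endpoint, _ := params.Load("_MinioAddress")        // "localhost:9000" with the defaults above
	accessKeyID, _ := params.Load("minio.accessKeyID") // "minioadmin"
	secretAccessKey, _ := params.Load("minio.secretAccessKey")
	useSSLStr, _ := params.Load("minio.useSSL")
	useSSL, _ := strconv.ParseBool(useSSLStr)

	client, err := minio.New(endpoint, &minio.Options{
		Creds:  credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
		Secure: useSSL,
	})
	if err != nil {
		log.Fatal(err)
	}

	kv, err := miniokv.NewMinIOKV(context.Background(), client, "example-bucket")
	if err != nil {
		log.Fatal(err)
	}
	defer kv.Close()
}
```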
@ -20,6 +20,22 @@ services:
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||
ports:
|
||||
- "9000:9000"
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
networks:
|
||||
milvus:
|
||||
|
||||
|
@ -21,6 +21,7 @@ services:
|
||||
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
|
||||
ETCD_ADDRESS: ${ETCD_ADDRESS}
|
||||
MASTER_ADDRESS: ${MASTER_ADDRESS}
|
||||
MINIO_ADDRESS: ${MINIO_ADDRESS}
|
||||
volumes: &ubuntu-volumes
|
||||
- .:/go/src/github.com/zilliztech/milvus-distributed:delegated
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.docker}/${ARCH}-ubuntu${UBUNTU}-cache:/ccache:delegated
|
||||
@ -45,6 +46,7 @@ services:
|
||||
PULSAR_ADDRESS: ${PULSAR_ADDRESS}
|
||||
ETCD_ADDRESS: ${ETCD_ADDRESS}
|
||||
MASTER_ADDRESS: ${MASTER_ADDRESS}
|
||||
MINIO_ADDRESS: ${MINIO_ADDRESS}
|
||||
volumes:
|
||||
- .:/go/src/github.com/zilliztech/milvus-distributed:delegated
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.docker}/${ARCH}-ubuntu${UBUNTU}-gdbserver-home:/home/debugger:delegated
|
||||
@ -59,16 +61,28 @@ services:
|
||||
etcd:
|
||||
image: quay.io/coreos/etcd:v3.4.13
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379
|
||||
ports:
|
||||
- "2379:2379"
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
pulsar:
|
||||
image: apachepulsar/pulsar:2.6.1
|
||||
command: bin/pulsar standalone
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||
ports:
|
||||
- "6650:6650"
|
||||
- "9000:9000"
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
networks:
|
||||
- milvus
|
||||
|
||||
|
@ -184,7 +184,7 @@ Note that *tenantId*, *proxyId*, *collectionId*, *segmentId* are unique strings
|
||||
|
||||
```go
|
||||
type metaTable struct {
|
||||
kv kv.Base // client of a reliable kv service, i.e. etcd client
|
||||
kv kv.TxnBase // client of a reliable kv service, i.e. etcd client
|
||||
tenantId2Meta map[UniqueId]TenantMeta // tenant id to tenant meta
|
||||
proxyId2Meta map[UniqueId]ProxyMeta // proxy id to proxy meta
|
||||
collId2Meta map[UniqueId]CollectionMeta // collection id to collection meta
|
||||
@ -216,7 +216,7 @@ func (meta *metaTable) GetSegmentById(segId UniqueId)(*SegmentMeta, error)
|
||||
func (meta *metaTable) DeleteSegment(segId UniqueId) error
|
||||
func (meta *metaTable) CloseSegment(segId UniqueId, closeTs Timestamp, num_rows int64) error
|
||||
|
||||
func NewMetaTable(kv kv.Base) (*metaTable,error)
|
||||
func NewMetaTable(kv kv.TxnBase) (*metaTable,error)
|
||||
```
|
||||
|
||||
*metaTable* maintains meta both in memory and in *etcdKV*, keeping the two consistent. All its member functions may be called concurrently.
|
||||
|
@ -54,7 +54,7 @@ func (ia *IDAllocator) syncID() {
|
||||
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Panic("syncID Failed!!!!!")
|
||||
log.Println("syncID Failed!!!!!")
|
||||
return
|
||||
}
|
||||
ia.idStart = resp.GetID()
|
||||
|
@ -210,7 +210,7 @@ func (sa *SegIDAssigner) syncSegments() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Panic("syncID Failed!!!!!")
|
||||
log.Println("syncSemgnet Failed!!!!!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func (ta *TimestampAllocator) syncTs() {
|
||||
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Panic("syncID Failed!!!!!")
|
||||
log.Println("syncTimestamp Failed!!!!!")
|
||||
return
|
||||
}
|
||||
ta.lastTsBegin = resp.GetTimestamp()
|
||||
|
@ -93,8 +93,8 @@ endif ()
|
||||
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
|
||||
|
||||
if (KNOWHERE_BUILD_TESTS)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
|
||||
add_subdirectory(unittest)
|
||||
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
|
||||
#add_subdirectory(unittest)
|
||||
endif ()
|
||||
|
||||
config_summary()
|
||||
|
@ -143,6 +143,7 @@ endif ()
|
||||
|
||||
target_link_libraries(
|
||||
knowhere
|
||||
milvus_utils
|
||||
${depend_libs}
|
||||
)
|
||||
|
||||
|
@ -12,4 +12,4 @@ set(MILVUS_QUERY_SRCS
|
||||
BruteForceSearch.cpp
|
||||
)
|
||||
add_library(milvus_query ${MILVUS_QUERY_SRCS})
|
||||
target_link_libraries(milvus_query milvus_proto)
|
||||
target_link_libraries(milvus_query milvus_proto milvus_utils)
|
||||
|
@ -24,18 +24,22 @@ namespace milvus::query {
|
||||
|
||||
static std::unique_ptr<VectorPlanNode>
|
||||
ParseVecNode(Plan* plan, const Json& out_body) {
|
||||
Assert(out_body.is_object());
|
||||
// TODO add binary info
|
||||
auto vec_node = std::make_unique<FloatVectorANNS>();
|
||||
Assert(out_body.size() == 1);
|
||||
auto iter = out_body.begin();
|
||||
std::string field_name = iter.key();
|
||||
auto& vec_info = iter.value();
|
||||
Assert(vec_info.is_object());
|
||||
auto topK = vec_info["topk"];
|
||||
AssertInfo(topK > 0, "topK must greater than 0");
|
||||
AssertInfo(topK < 16384, "topK is too large");
|
||||
vec_node->query_info_.topK_ = topK;
|
||||
vec_node->query_info_.metric_type_ = vec_info["metric_type"];
|
||||
vec_node->query_info_.search_params_ = vec_info["params"];
|
||||
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
|
||||
vec_node->query_info_.search_params_ = vec_info.at("params");
|
||||
vec_node->query_info_.field_id_ = field_name;
|
||||
vec_node->placeholder_tag_ = vec_info["query"];
|
||||
vec_node->placeholder_tag_ = vec_info.at("query");
|
||||
auto tag = vec_node->placeholder_tag_;
|
||||
AssertInfo(!plan->tag2field_.count(tag), "duplicated placeholder tag");
|
||||
plan->tag2field_.emplace(tag, field_name);
|
||||
@ -56,6 +60,8 @@ to_lower(const std::string& raw) {
|
||||
return data;
|
||||
}
|
||||
|
||||
template <class...>
|
||||
constexpr std::false_type always_false{};
|
||||
template <typename T>
|
||||
std::unique_ptr<Expr>
|
||||
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
@ -63,11 +69,19 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
expr->data_type_ = data_type;
|
||||
expr->field_id_ = field_name;
|
||||
Assert(body.is_object());
|
||||
for (auto& item : body.items()) {
|
||||
auto op_name = to_lower(item.key());
|
||||
|
||||
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
|
||||
auto op = RangeExpr::mapping_.at(op_name);
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
Assert(item.value().is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(item.value().is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
}
|
||||
T value = item.value();
|
||||
expr->conditions_.emplace_back(op, value);
|
||||
}
|
||||
@ -83,8 +97,10 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(!field_is_vector(data_type));
|
||||
switch (data_type) {
|
||||
case DataType::BOOL:
|
||||
return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
case DataType::BOOL: {
|
||||
PanicInfo("bool is not supported in Range node");
|
||||
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT8:
|
||||
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
|
||||
case DataType::INT16:
|
||||
@ -109,16 +125,17 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
||||
nlohmann::json vec_pack;
|
||||
std::optional<std::unique_ptr<Expr>> predicate;
|
||||
|
||||
auto& bool_dsl = dsl["bool"];
|
||||
auto& bool_dsl = dsl.at("bool");
|
||||
if (bool_dsl.contains("must")) {
|
||||
auto& packs = bool_dsl["must"];
|
||||
auto& packs = bool_dsl.at("must");
|
||||
Assert(packs.is_array());
|
||||
for (auto& pack : packs) {
|
||||
if (pack.contains("vector")) {
|
||||
auto& out_body = pack["vector"];
|
||||
auto& out_body = pack.at("vector");
|
||||
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
|
||||
} else if (pack.contains("range")) {
|
||||
AssertInfo(!predicate, "unsupported complex DSL");
|
||||
auto& out_body = pack["range"];
|
||||
auto& out_body = pack.at("range");
|
||||
predicate = ParseRangeNode(schema, out_body);
|
||||
} else {
|
||||
PanicInfo("unsupported node");
|
||||
@ -126,7 +143,7 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
||||
}
|
||||
AssertInfo(plan->plan_node_, "vector node not found");
|
||||
} else if (bool_dsl.contains("vector")) {
|
||||
auto& out_body = bool_dsl["vector"];
|
||||
auto& out_body = bool_dsl.at("vector");
|
||||
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
|
||||
Assert(plan->plan_node_);
|
||||
} else {
|
||||
|
@ -17,7 +17,7 @@ add_library(milvus_segcore SHARED
|
||||
)
|
||||
|
||||
target_link_libraries(milvus_segcore
|
||||
tbb utils pthread knowhere log milvus_proto
|
||||
tbb milvus_utils pthread knowhere log milvus_proto
|
||||
dl backtrace
|
||||
milvus_common
|
||||
milvus_query
|
||||
|
@ -19,6 +19,9 @@ NewCollection(const char* collection_proto) {
|
||||
|
||||
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
|
||||
|
||||
// TODO: delete print
|
||||
std::cout << "create collection " << collection->get_collection_name() << std::endl;
|
||||
|
||||
return (void*)collection.release();
|
||||
}
|
||||
|
||||
|
@ -50,97 +50,187 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
|
||||
}
|
||||
|
||||
struct SearchResultPair {
|
||||
uint64_t id_;
|
||||
float distance_;
|
||||
int64_t segment_id_;
|
||||
SearchResult* search_result_;
|
||||
int64_t offset_;
|
||||
int64_t index_;
|
||||
|
||||
SearchResultPair(uint64_t id, float distance, int64_t segment_id)
|
||||
: id_(id), distance_(distance), segment_id_(segment_id) {
|
||||
SearchResultPair(float distance, SearchResult* search_result, int64_t offset, int64_t index)
|
||||
: distance_(distance), search_result_(search_result), offset_(offset), index_(index) {
|
||||
}
|
||||
|
||||
bool
|
||||
operator<(const SearchResultPair& pair) const {
|
||||
return (distance_ < pair.distance_);
|
||||
}
|
||||
|
||||
void
|
||||
reset_distance() {
|
||||
distance_ = search_result_->result_distances_[offset_];
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
GetResultData(std::vector<SearchResult*>& search_results,
|
||||
SearchResult& final_result,
|
||||
GetResultData(std::vector<std::vector<int64_t>>& search_records,
|
||||
std::vector<SearchResult*>& search_results,
|
||||
int64_t query_offset,
|
||||
bool* is_selected,
|
||||
int64_t topk) {
|
||||
auto num_segments = search_results.size();
|
||||
std::map<int, int> iter_loc_peer_result;
|
||||
AssertInfo(num_segments > 0, "num segment must greater than 0");
|
||||
std::vector<SearchResultPair> result_pairs;
|
||||
for (int j = 0; j < num_segments; ++j) {
|
||||
auto id = search_results[j]->result_ids_[query_offset];
|
||||
auto distance = search_results[j]->result_distances_[query_offset];
|
||||
result_pairs.push_back(SearchResultPair(id, distance, j));
|
||||
iter_loc_peer_result[j] = query_offset;
|
||||
auto search_result = search_results[j];
|
||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||
result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
|
||||
}
|
||||
std::sort(result_pairs.begin(), result_pairs.end());
|
||||
final_result.result_ids_.push_back(result_pairs[0].id_);
|
||||
final_result.result_distances_.push_back(result_pairs[0].distance_);
|
||||
|
||||
for (int i = 1; i < topk; ++i) {
|
||||
auto segment_id = result_pairs[0].segment_id_;
|
||||
auto query_offset = ++(iter_loc_peer_result[segment_id]);
|
||||
auto id = search_results[segment_id]->result_ids_[query_offset];
|
||||
auto distance = search_results[segment_id]->result_distances_[query_offset];
|
||||
result_pairs[0] = SearchResultPair(id, distance, segment_id);
|
||||
int64_t loc_offset = query_offset;
|
||||
AssertInfo(topk > 0, "topK must greater than 0");
|
||||
for (int i = 0; i < topk; ++i) {
|
||||
result_pairs[0].reset_distance();
|
||||
std::sort(result_pairs.begin(), result_pairs.end());
|
||||
final_result.result_ids_.push_back(result_pairs[0].id_);
|
||||
final_result.result_distances_.push_back(result_pairs[0].distance_);
|
||||
auto& result_pair = result_pairs[0];
|
||||
auto index = result_pair.index_;
|
||||
is_selected[index] = true;
|
||||
result_pair.search_result_->result_offsets_.push_back(loc_offset++);
|
||||
search_records[index].push_back(result_pair.offset_++);
|
||||
}
|
||||
}
|
||||
|
||||
CQueryResult
|
||||
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) {
|
||||
void
|
||||
ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
|
||||
std::vector<SearchResult*>& search_results,
|
||||
bool* is_selected) {
|
||||
auto num_segments = search_results.size();
|
||||
AssertInfo(num_segments > 0, "num segment must greater than 0");
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
if (is_selected[i] == false) {
|
||||
continue;
|
||||
}
|
||||
auto search_result = search_results[i];
|
||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||
|
||||
std::vector<float> result_distances;
|
||||
std::vector<int64_t> internal_seg_offsets;
|
||||
std::vector<int64_t> result_ids;
|
||||
|
||||
for (int j = 0; j < search_records[i].size(); j++) {
|
||||
auto& offset = search_records[i][j];
|
||||
auto distance = search_result->result_distances_[offset];
|
||||
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
|
||||
auto id = search_result->result_ids_[offset];
|
||||
result_distances.push_back(distance);
|
||||
internal_seg_offsets.push_back(internal_seg_offset);
|
||||
result_ids.push_back(id);
|
||||
}
|
||||
|
||||
search_result->result_distances_ = result_distances;
|
||||
search_result->internal_seg_offsets_ = internal_seg_offsets;
|
||||
search_result->result_ids_ = result_ids;
|
||||
}
|
||||
}
|
||||
|
||||
CStatus
|
||||
ReduceQueryResults(CQueryResult* c_search_results, int64_t num_segments, bool* is_selected) {
|
||||
std::vector<SearchResult*> search_results;
|
||||
for (int i = 0; i < num_segments; ++i) {
|
||||
search_results.push_back((SearchResult*)query_results[i]);
|
||||
search_results.push_back((SearchResult*)c_search_results[i]);
|
||||
}
|
||||
auto topk = search_results[0]->topK_;
|
||||
auto num_queries = search_results[0]->num_queries_;
|
||||
auto final_result = std::make_unique<SearchResult>();
|
||||
try {
|
||||
auto topk = search_results[0]->topK_;
|
||||
auto num_queries = search_results[0]->num_queries_;
|
||||
std::vector<std::vector<int64_t>> search_records(num_segments);
|
||||
|
||||
int64_t query_offset = 0;
|
||||
for (int j = 0; j < num_queries; ++j) {
|
||||
GetResultData(search_results, *final_result, query_offset, topk);
|
||||
query_offset += topk;
|
||||
int64_t query_offset = 0;
|
||||
for (int j = 0; j < num_queries; ++j) {
|
||||
GetResultData(search_records, search_results, query_offset, is_selected, topk);
|
||||
query_offset += topk;
|
||||
}
|
||||
ResetSearchResult(search_records, search_results, is_selected);
|
||||
auto status = CStatus();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
return status;
|
||||
} catch (std::exception& e) {
|
||||
auto status = CStatus();
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
return status;
|
||||
}
|
||||
|
||||
return (CQueryResult)final_result.release();
|
||||
}
|
||||
|
||||
CMarshaledHits
|
||||
ReorganizeQueryResults(CQueryResult c_query_result,
|
||||
CPlan c_plan,
|
||||
CStatus
|
||||
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
|
||||
CPlaceholderGroup* c_placeholder_groups,
|
||||
int64_t num_groups) {
|
||||
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
|
||||
auto search_result = (milvus::engine::QueryResult*)c_query_result;
|
||||
auto& result_ids = search_result->result_ids_;
|
||||
auto& result_distances = search_result->result_distances_;
|
||||
auto topk = GetTopK(c_plan);
|
||||
int64_t queries_offset = 0;
|
||||
for (int i = 0; i < num_groups; i++) {
|
||||
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
|
||||
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
|
||||
for (int j = 0; j < num_queries; j++) {
|
||||
auto index = topk * queries_offset++;
|
||||
milvus::proto::service::Hits hits;
|
||||
for (int k = index; k < index + topk; k++) {
|
||||
hits.add_ids(result_ids[k]);
|
||||
hits.add_scores(result_distances[k]);
|
||||
}
|
||||
auto blob = hits.SerializeAsString();
|
||||
hits_peer_group.hits_.push_back(blob);
|
||||
hits_peer_group.blob_length_.push_back(blob.size());
|
||||
int64_t num_groups,
|
||||
CQueryResult* c_search_results,
|
||||
bool* is_selected,
|
||||
int64_t num_segments,
|
||||
CPlan c_plan) {
|
||||
try {
|
||||
auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
|
||||
auto topk = GetTopK(c_plan);
|
||||
std::vector<int64_t> num_queries_peer_group;
|
||||
int64_t total_num_queries = 0;
|
||||
for (int i = 0; i < num_groups; i++) {
|
||||
auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
|
||||
num_queries_peer_group.push_back(num_queries);
|
||||
total_num_queries += num_queries;
|
||||
}
|
||||
}
|
||||
|
||||
return (CMarshaledHits)marshaledHits.release();
|
||||
std::vector<float> result_distances(total_num_queries * topk);
|
||||
std::vector<int64_t> result_ids(total_num_queries * topk);
|
||||
std::vector<std::vector<char>> row_datas(total_num_queries * topk);
|
||||
|
||||
int64_t count = 0;
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
if (is_selected[i] == false) {
|
||||
continue;
|
||||
}
|
||||
auto search_result = (SearchResult*)c_search_results[i];
|
||||
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
|
||||
auto size = search_result->result_offsets_.size();
|
||||
for (int j = 0; j < size; j++) {
|
||||
auto loc = search_result->result_offsets_[j];
|
||||
result_distances[loc] = search_result->result_distances_[j];
|
||||
row_datas[loc] = search_result->row_data_[j];
|
||||
result_ids[loc] = search_result->result_ids_[j];
|
||||
}
|
||||
count += size;
|
||||
}
|
||||
AssertInfo(count == total_num_queries * topk, "the reduced result's size is less than total_num_queries*topk");
|
||||
|
||||
int64_t fill_hit_offset = 0;
|
||||
for (int i = 0; i < num_groups; i++) {
|
||||
MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
|
||||
for (int j = 0; j < num_queries_peer_group[i]; j++) {
|
||||
milvus::proto::service::Hits hits;
|
||||
for (int k = 0; k < topk; k++, fill_hit_offset++) {
|
||||
hits.add_ids(result_ids[fill_hit_offset]);
|
||||
hits.add_scores(result_distances[fill_hit_offset]);
|
||||
auto& row_data = row_datas[fill_hit_offset];
|
||||
hits.add_row_data(row_data.data(), row_data.size());
|
||||
}
|
||||
auto blob = hits.SerializeAsString();
|
||||
hits_peer_group.hits_.push_back(blob);
|
||||
hits_peer_group.blob_length_.push_back(blob.size());
|
||||
}
|
||||
}
|
||||
|
||||
auto status = CStatus();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
auto marshled_res = (CMarshaledHits)marshaledHits.release();
|
||||
*c_marshaled_hits = marshled_res;
|
||||
return status;
|
||||
} catch (std::exception& e) {
|
||||
auto status = CStatus();
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
*c_marshaled_hits = nullptr;
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
|
@ -25,14 +25,17 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
|
||||
int
|
||||
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
|
||||
|
||||
CQueryResult
|
||||
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments);
|
||||
CStatus
|
||||
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments, bool* is_selected);
|
||||
|
||||
CMarshaledHits
|
||||
ReorganizeQueryResults(CQueryResult query_result,
|
||||
CPlan c_plan,
|
||||
CStatus
|
||||
ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
|
||||
CPlaceholderGroup* c_placeholder_groups,
|
||||
int64_t num_groups);
|
||||
int64_t num_groups,
|
||||
CQueryResult* c_search_results,
|
||||
bool* is_selected,
|
||||
int64_t num_segments,
|
||||
CPlan c_plan);
|
||||
|
||||
int64_t
|
||||
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
|
||||
|
@ -155,6 +155,24 @@ Search(CSegmentBase c_segment,
|
||||
return status;
|
||||
}
|
||||
|
||||
CStatus
|
||||
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
|
||||
auto segment = (milvus::segcore::SegmentBase*)c_segment;
|
||||
auto plan = (milvus::query::Plan*)c_plan;
|
||||
auto result = (milvus::engine::QueryResult*)c_result;
|
||||
|
||||
auto status = CStatus();
|
||||
try {
|
||||
auto res = segment->FillTargetEntry(plan, *result);
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
} catch (std::runtime_error& e) {
|
||||
status.error_code = UnexpectedException;
|
||||
status.error_msg = strdup(e.what());
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
|
@ -61,6 +61,9 @@ Search(CSegmentBase c_segment,
|
||||
int num_groups,
|
||||
CQueryResult* result);
|
||||
|
||||
CStatus
|
||||
FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult result);
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
|
@ -17,8 +17,8 @@ set(UTILS_FILES
|
||||
EasyAssert.cpp
|
||||
)
|
||||
|
||||
add_library( utils STATIC ${UTILS_FILES} )
|
||||
add_library( milvus_utils STATIC ${UTILS_FILES} )
|
||||
|
||||
target_link_libraries(utils
|
||||
target_link_libraries(milvus_utils
|
||||
libboost_filesystem.a
|
||||
libboost_system.a)
|
||||
|
@ -11,11 +11,20 @@
|
||||
|
||||
#include <iostream>
|
||||
#include "EasyAssert.h"
|
||||
// #define BOOST_STACKTRACE_USE_ADDR2LINE
|
||||
#define BOOST_STACKTRACE_USE_BACKTRACE
|
||||
#include <boost/stacktrace.hpp>
|
||||
#include <sstream>
|
||||
|
||||
namespace milvus::impl {
|
||||
|
||||
std::string
|
||||
EasyStackTrace() {
|
||||
auto stack_info = boost::stacktrace::stacktrace();
|
||||
std::ostringstream ss;
|
||||
ss << stack_info;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void
|
||||
EasyAssertInfo(
|
||||
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) {
|
||||
@ -26,11 +35,15 @@ EasyAssertInfo(
|
||||
if (!extra_info.empty()) {
|
||||
info += " => " + std::string(extra_info);
|
||||
}
|
||||
auto fuck = boost::stacktrace::stacktrace();
|
||||
std::cout << fuck;
|
||||
// std::string s = fuck;
|
||||
// info += ;
|
||||
throw std::runtime_error(info);
|
||||
|
||||
throw std::runtime_error(info + "\n" + EasyStackTrace());
|
||||
}
|
||||
}
|
||||
|
||||
[[noreturn]] void
|
||||
ThrowWithTrace(const std::exception& exception) {
|
||||
auto err_msg = exception.what() + std::string("\n") + EasyStackTrace();
|
||||
throw std::runtime_error(err_msg);
|
||||
}
|
||||
|
||||
} // namespace milvus::impl
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <string_view>
|
||||
#include <exception>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
@ -20,7 +21,11 @@ namespace milvus::impl {
|
||||
void
|
||||
EasyAssertInfo(
|
||||
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
|
||||
}
|
||||
|
||||
[[noreturn]] void
|
||||
ThrowWithTrace(const std::exception& exception);
|
||||
|
||||
} // namespace milvus::impl
|
||||
|
||||
#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
|
||||
#define Assert(expr) AssertInfo((expr), "")
|
||||
|
@ -11,6 +11,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils/EasyAssert.h"
|
||||
|
||||
#define JSON_ASSERT(expr) Assert((expr))
|
||||
// TODO: dispatch error by type
|
||||
|
||||
#define JSON_THROW_USER(e) milvus::impl::ThrowWithTrace((e))
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
namespace milvus {
|
||||
|
@ -24,5 +24,6 @@ target_link_libraries(all_tests
|
||||
knowhere
|
||||
log
|
||||
pthread
|
||||
milvus_utils
|
||||
)
|
||||
install (TARGETS all_tests DESTINATION unittest)
|
||||
|
@ -641,8 +641,14 @@ TEST(CApiTest, Reduce) {
|
||||
results.push_back(res1);
|
||||
results.push_back(res2);
|
||||
|
||||
auto reduced_search_result = ReduceQueryResults(results.data(), 2);
|
||||
auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1);
|
||||
bool is_selected[1] = {false};
|
||||
status = ReduceQueryResults(results.data(), 1, is_selected);
|
||||
assert(status.error_code == Success);
|
||||
FillTargetEntry(segment, plan, res1);
|
||||
void* reorganize_search_result = nullptr;
|
||||
status = ReorganizeQueryResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(), is_selected,
|
||||
1, plan);
|
||||
assert(status.error_code == Success);
|
||||
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
|
||||
assert(hits_blob_size > 0);
|
||||
std::vector<char> hits_blob;
|
||||
@ -660,7 +666,6 @@ TEST(CApiTest, Reduce) {
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
DeleteQueryResult(res1);
|
||||
DeleteQueryResult(res2);
|
||||
DeleteQueryResult(reduced_search_result);
|
||||
DeleteMarshaledHits(reorganize_search_result);
|
||||
DeleteCollection(collection);
|
||||
DeleteSegment(segment);
|
||||
|
@ -107,6 +107,83 @@ TEST(Expr, Range) {
|
||||
std::cout << out.dump(4);
|
||||
}
|
||||
|
||||
TEST(Expr, InvalidRange) {
|
||||
SUCCEED();
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
std::string dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GT": 1,
|
||||
"LT": "100"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
|
||||
}
|
||||
|
||||
TEST(Expr, InvalidDSL) {
|
||||
SUCCEED();
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
std::string dsl_string = R"(
|
||||
{
|
||||
"float": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GT": 1,
|
||||
"LT": 100
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
|
||||
}
|
||||
|
||||
TEST(Expr, ShowExecutor) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
|
@ -1,4 +1,4 @@
|
||||
package kv
|
||||
package etcdkv
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,11 +1,11 @@
|
||||
package kv_test
|
||||
package etcdkv_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
@ -28,7 +28,7 @@ func TestEtcdKV_Load(t *testing.T) {
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
rootPath := "/etcd/test/root"
|
||||
etcdKV := kv.NewEtcdKV(cli, rootPath)
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
|
||||
|
||||
defer etcdKV.Close()
|
||||
defer etcdKV.RemoveWithPrefix("")
|
||||
@ -86,7 +86,7 @@ func TestEtcdKV_MultiSave(t *testing.T) {
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
rootPath := "/etcd/test/root"
|
||||
etcdKV := kv.NewEtcdKV(cli, rootPath)
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
|
||||
|
||||
defer etcdKV.Close()
|
||||
defer etcdKV.RemoveWithPrefix("")
|
||||
@ -117,7 +117,7 @@ func TestEtcdKV_Remove(t *testing.T) {
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
rootPath := "/etcd/test/root"
|
||||
etcdKV := kv.NewEtcdKV(cli, rootPath)
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
|
||||
|
||||
defer etcdKV.Close()
|
||||
defer etcdKV.RemoveWithPrefix("")
|
||||
@ -188,7 +188,7 @@ func TestEtcdKV_MultiSaveAndRemove(t *testing.T) {
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
rootPath := "/etcd/test/root"
|
||||
etcdKV := kv.NewEtcdKV(cli, rootPath)
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, rootPath)
|
||||
|
||||
defer etcdKV.Close()
|
||||
defer etcdKV.RemoveWithPrefix("")
|
@ -8,6 +8,11 @@ type Base interface {
|
||||
MultiSave(kvs map[string]string) error
|
||||
Remove(key string) error
|
||||
MultiRemove(keys []string) error
|
||||
MultiSaveAndRemove(saves map[string]string, removals []string) error
|
||||
|
||||
Close()
|
||||
}
|
||||
|
||||
type TxnBase interface {
|
||||
Base
|
||||
MultiSaveAndRemove(saves map[string]string, removals []string) error
|
||||
}
|
||||
|
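The hunk above splits the kv abstraction: `Base` keeps the plain key/value operations and gains `Close()`, while the new `TxnBase` embeds `Base` and adds `MultiSaveAndRemove`, the transactional operation that the master's `metaTable` and the allocators now require (they switch from `kv.Base` to `kv.TxnBase` later in this diff). A hedged sketch of the resulting shape follows; the exact method set of `Base` is partially inferred, since the hunk only shows part of the interface.

```go
// Sketch of the interface split introduced above (the method set of Base is
// partially inferred; treat this as an illustration, not the exact file).
package kv

type Base interface {
	Load(key string) (string, error)
	Save(key, value string) error
	MultiSave(kvs map[string]string) error
	Remove(key string) error
	MultiRemove(keys []string) error
	Close()
}

// TxnBase adds the atomic save+remove that etcd transactions provide. An
// object store such as MinIO has no equivalent atomic multi-op, which is why
// the new MinIOKV covers only the Base-style operations, while metaTable and
// the TSO/ID allocators in this commit now ask for TxnBase.
type TxnBase interface {
	Base
	MultiSaveAndRemove(saves map[string]string, removals []string) error
}
```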
@ -1,4 +1,4 @@
|
||||
package kv
|
||||
package memkv
|
||||
|
||||
import (
|
||||
"sync"
|
149
internal/kv/minio/minio_kv.go
Normal file
@ -0,0 +1,149 @@
|
||||
package miniokv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
)
|
||||
|
||||
type MinIOKV struct {
|
||||
ctx context.Context
|
||||
minioClient *minio.Client
|
||||
bucketName string
|
||||
}
|
||||
|
||||
// NewMinIOKV creates a new MinIO kv.
|
||||
func NewMinIOKV(ctx context.Context, client *minio.Client, bucketName string) (*MinIOKV, error) {
|
||||
|
||||
bucketExists, err := client.BucketExists(ctx, bucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bucketExists {
|
||||
err = client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &MinIOKV{
|
||||
ctx: ctx,
|
||||
minioClient: client,
|
||||
bucketName: bucketName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) LoadWithPrefix(key string) ([]string, []string, error) {
|
||||
objects := kv.minioClient.ListObjects(kv.ctx, kv.bucketName, minio.ListObjectsOptions{Prefix: key})
|
||||
|
||||
var objectsKeys []string
|
||||
var objectsValues []string
|
||||
|
||||
for object := range objects {
|
||||
objectsKeys = append(objectsKeys, object.Key)
|
||||
}
|
||||
objectsValues, err := kv.MultiLoad(objectsKeys)
|
||||
if err != nil {
|
||||
log.Printf("cannot load value with prefix:%s", key)
|
||||
}
|
||||
|
||||
return objectsKeys, objectsValues, nil
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) Load(key string) (string, error) {
|
||||
object, err := kv.minioClient.GetObject(kv.ctx, kv.bucketName, key, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
buf := new(strings.Builder)
|
||||
_, err = io.Copy(buf, object)
|
||||
if err != nil && err != io.EOF {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) MultiLoad(keys []string) ([]string, error) {
|
||||
var resultErr error
|
||||
var objectsValues []string
|
||||
for _, key := range keys {
|
||||
objectValue, err := kv.Load(key)
|
||||
if err != nil {
|
||||
if resultErr == nil {
|
||||
resultErr = err
|
||||
}
|
||||
}
|
||||
objectsValues = append(objectsValues, objectValue)
|
||||
}
|
||||
|
||||
return objectsValues, resultErr
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) Save(key, value string) error {
|
||||
reader := strings.NewReader(value)
|
||||
_, err := kv.minioClient.PutObject(kv.ctx, kv.bucketName, key, reader, int64(len(value)), minio.PutObjectOptions{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) MultiSave(kvs map[string]string) error {
|
||||
var resultErr error
|
||||
for key, value := range kvs {
|
||||
err := kv.Save(key, value)
|
||||
if err != nil {
|
||||
if resultErr == nil {
|
||||
resultErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
return resultErr
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) RemoveWithPrefix(prefix string) error {
|
||||
objectsCh := make(chan minio.ObjectInfo)
|
||||
|
||||
go func() {
|
||||
defer close(objectsCh)
|
||||
|
||||
for object := range kv.minioClient.ListObjects(kv.ctx, kv.bucketName, minio.ListObjectsOptions{Prefix: prefix}) {
|
||||
objectsCh <- object
|
||||
}
|
||||
}()
|
||||
|
||||
for rErr := range kv.minioClient.RemoveObjects(kv.ctx, kv.bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
|
||||
if rErr.Err != nil {
|
||||
return rErr.Err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) Remove(key string) error {
|
||||
err := kv.minioClient.RemoveObject(kv.ctx, kv.bucketName, string(key), minio.RemoveObjectOptions{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) MultiRemove(keys []string) error {
|
||||
var resultErr error
|
||||
for _, key := range keys {
|
||||
err := kv.Remove(key)
|
||||
if err != nil {
|
||||
if resultErr == nil {
|
||||
resultErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
return resultErr
|
||||
}
|
||||
|
||||
func (kv *MinIOKV) Close() {
|
||||
|
||||
}
|
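With the implementation above in place, a caller gets plain object-store semantics: every key is a separate object, and the `Multi*` helpers loop over keys, collecting the first error while still returning one result slot per key (the test below relies on exactly that for a missing key). A minimal usage sketch follows; `demoMinIOKV` and the key names are made up for illustration, and constructing the `*miniokv.MinIOKV` is done as in the test file below.

```go
// Sketch: exercising the new MinIOKV once it has been constructed
// (client/bucket setup as in minio_kv_test.go below). Illustrative only.
package kvdemo

import (
	"fmt"

	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
)

func demoMinIOKV(kv *miniokv.MinIOKV) error {
	if err := kv.Save("segment/1", "payload-1"); err != nil {
		return err
	}

	val, err := kv.Load("segment/1")
	if err != nil {
		return err
	}
	fmt.Println("loaded:", val) // "payload-1"

	// MultiLoad returns one slot per key plus the first error it hit, so a
	// missing key shows up as an empty string next to a non-nil error.
	vals, err := kv.MultiLoad([]string{"segment/1", "segment/404"})
	fmt.Println(len(vals), err != nil) // 2 true

	// Keys are flat strings, so prefix removal is the natural cleanup.
	return kv.RemoveWithPrefix("segment/")
}
```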
195
internal/kv/minio/minio_kv_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package miniokv_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var Params paramtable.BaseTable
|
||||
|
||||
func TestMinIOKV_Load(t *testing.T) {
|
||||
Params.Init()
|
||||
endPoint, _ := Params.Load("_MinioAddress")
|
||||
accessKeyID, _ := Params.Load("minio.accessKeyID")
|
||||
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
|
||||
useSSLStr, _ := Params.Load("minio.useSSL")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
useSSL, _ := strconv.ParseBool(useSSLStr)
|
||||
|
||||
minioClient, err := minio.New(endPoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
bucketName := "fantastic-tech-test"
|
||||
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
|
||||
assert.Nil(t, err)
|
||||
defer MinIOKV.RemoveWithPrefix("")
|
||||
|
||||
err = MinIOKV.Save("abc", "123")
|
||||
assert.Nil(t, err)
|
||||
err = MinIOKV.Save("abcd", "1234")
|
||||
assert.Nil(t, err)
|
||||
|
||||
val, err := MinIOKV.Load("abc")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "123")
|
||||
|
||||
keys, vals, err := MinIOKV.LoadWithPrefix("abc")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(keys), len(vals))
|
||||
assert.Equal(t, len(keys), 2)
|
||||
|
||||
assert.Equal(t, vals[0], "123")
|
||||
assert.Equal(t, vals[1], "1234")
|
||||
|
||||
err = MinIOKV.Save("key_1", "123")
|
||||
assert.Nil(t, err)
|
||||
err = MinIOKV.Save("key_2", "456")
|
||||
assert.Nil(t, err)
|
||||
err = MinIOKV.Save("key_3", "789")
|
||||
assert.Nil(t, err)
|
||||
|
||||
keys = []string{"key_1", "key_100"}
|
||||
|
||||
vals, err = MinIOKV.MultiLoad(keys)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, len(vals), len(keys))
|
||||
assert.Equal(t, vals[0], "123")
|
||||
assert.Empty(t, vals[1])
|
||||
|
||||
keys = []string{"key_1", "key_2"}
|
||||
|
||||
vals, err = MinIOKV.MultiLoad(keys)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(vals), len(keys))
|
||||
assert.Equal(t, vals[0], "123")
|
||||
assert.Equal(t, vals[1], "456")
|
||||
|
||||
}
|
||||
|
||||
func TestMinIOKV_MultiSave(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
Params.Init()
|
||||
endPoint, _ := Params.Load("_MinioAddress")
|
||||
accessKeyID, _ := Params.Load("minio.accessKeyID")
|
||||
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
|
||||
useSSLStr, _ := Params.Load("minio.useSSL")
|
||||
useSSL, _ := strconv.ParseBool(useSSLStr)
|
||||
|
||||
minioClient, err := minio.New(endPoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
bucketName := "fantastic-tech-test"
|
||||
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
|
||||
assert.Nil(t, err)
|
||||
defer MinIOKV.RemoveWithPrefix("")
|
||||
|
||||
err = MinIOKV.Save("key_1", "111")
|
||||
assert.Nil(t, err)
|
||||
|
||||
kvs := map[string]string{
|
||||
"key_1": "123",
|
||||
"key_2": "456",
|
||||
}
|
||||
|
||||
err = MinIOKV.MultiSave(kvs)
|
||||
assert.Nil(t, err)
|
||||
|
||||
val, err := MinIOKV.Load("key_1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "123")
|
||||
}
|
||||
|
||||
func TestMinIOKV_Remove(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
Params.Init()
|
||||
endPoint, _ := Params.Load("_MinioAddress")
|
||||
accessKeyID, _ := Params.Load("minio.accessKeyID")
|
||||
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
|
||||
useSSLStr, _ := Params.Load("minio.useSSL")
|
||||
useSSL, _ := strconv.ParseBool(useSSLStr)
|
||||
|
||||
minioClient, err := minio.New(endPoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
bucketName := "fantastic-tech-test"
|
||||
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
|
||||
assert.Nil(t, err)
|
||||
defer MinIOKV.RemoveWithPrefix("")
|
||||
|
||||
err = MinIOKV.Save("key_1", "123")
|
||||
assert.Nil(t, err)
|
||||
err = MinIOKV.Save("key_2", "456")
|
||||
assert.Nil(t, err)
|
||||
|
||||
val, err := MinIOKV.Load("key_1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "123")
|
||||
// delete "key_1"
|
||||
err = MinIOKV.Remove("key_1")
|
||||
assert.Nil(t, err)
|
||||
val, err = MinIOKV.Load("key_1")
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, val)
|
||||
|
||||
val, err = MinIOKV.Load("key_2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "456")
|
||||
|
||||
keys, vals, err := MinIOKV.LoadWithPrefix("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(keys), len(vals))
|
||||
assert.Equal(t, len(keys), 1)
|
||||
|
||||
assert.Equal(t, vals[0], "456")
|
||||
|
||||
// MultiRemove
|
||||
err = MinIOKV.Save("key_1", "111")
|
||||
assert.Nil(t, err)
|
||||
|
||||
kvs := map[string]string{
|
||||
"key_1": "123",
|
||||
"key_2": "456",
|
||||
"key_3": "789",
|
||||
"key_4": "012",
|
||||
}
|
||||
|
||||
err = MinIOKV.MultiSave(kvs)
|
||||
assert.Nil(t, err)
|
||||
val, err = MinIOKV.Load("key_1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "123")
|
||||
val, err = MinIOKV.Load("key_3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, "789")
|
||||
|
||||
keys = []string{"key_1", "key_2", "key_3"}
|
||||
err = MinIOKV.MultiRemove(keys)
|
||||
assert.Nil(t, err)
|
||||
|
||||
val, err = MinIOKV.Load("key_1")
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, val)
|
||||
}
|
@ -1,14 +1,14 @@
|
||||
package mockkv
|
||||
|
||||
import (
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
)
|
||||
|
||||
// use MemoryKV to mock EtcdKV
|
||||
func NewEtcdKV() *kv.MemoryKV {
|
||||
return kv.NewMemoryKV()
|
||||
func NewEtcdKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
||||
|
||||
func NewMemoryKV() *kv.MemoryKV {
|
||||
return kv.NewMemoryKV()
|
||||
func NewMemoryKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
||||
|
@ -4,13 +4,13 @@ import (
|
||||
"log"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
)
|
||||
|
||||
type getSysConfigsTask struct {
|
||||
baseTask
|
||||
configkv *kv.EtcdKV
|
||||
configkv *etcdkv.EtcdKV
|
||||
req *internalpb.SysConfigRequest
|
||||
keys []string
|
||||
values []string
|
||||
|
@ -36,7 +36,7 @@ type GlobalTSOAllocator struct {
|
||||
}
|
||||
|
||||
// NewGlobalTSOAllocator creates a new global TSO allocator.
|
||||
func NewGlobalTSOAllocator(key string, kvBase kv.Base) *GlobalTSOAllocator {
|
||||
func NewGlobalTSOAllocator(key string, kvBase kv.TxnBase) *GlobalTSOAllocator {
|
||||
var saveInterval = 3 * time.Second
|
||||
return &GlobalTSOAllocator{
|
||||
tso: &timestampOracle{
|
||||
|
@ -271,7 +271,7 @@ func (s *Master) HasPartition(ctx context.Context, in *internalpb.HasPartitionRe
|
||||
return &servicepb.BoolResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
|
||||
Reason: "WaitToFinish failed",
|
||||
Reason: err.Error(),
|
||||
},
|
||||
Value: t.(*hasPartitionTask).hasPartition,
|
||||
}, nil
|
||||
|
@ -9,7 +9,7 @@ type GlobalIDAllocator struct {
|
||||
allocator Allocator
|
||||
}
|
||||
|
||||
func NewGlobalIDAllocator(key string, base kv.Base) *GlobalIDAllocator {
|
||||
func NewGlobalIDAllocator(key string, base kv.TxnBase) *GlobalIDAllocator {
|
||||
return &GlobalIDAllocator{
|
||||
allocator: NewGlobalTSOAllocator(key, base),
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
|
||||
@ -42,7 +42,7 @@ type Master struct {
|
||||
grpcServer *grpc.Server
|
||||
grpcErr chan error
|
||||
|
||||
kvBase *kv.EtcdKV
|
||||
kvBase *etcdkv.EtcdKV
|
||||
scheduler *ddRequestScheduler
|
||||
metaTable *metaTable
|
||||
timesSyncMsgProducer *timeSyncMsgProducer
|
||||
@ -63,12 +63,12 @@ type Master struct {
|
||||
tsoAllocator *GlobalTSOAllocator
|
||||
}
|
||||
|
||||
func newKVBase(kvRoot string, etcdAddr []string) *kv.EtcdKV {
|
||||
func newKVBase(kvRoot string, etcdAddr []string) *etcdkv.EtcdKV {
|
||||
cli, _ := clientv3.New(clientv3.Config{
|
||||
Endpoints: etcdAddr,
|
||||
DialTimeout: 5 * time.Second,
|
||||
})
|
||||
kvBase := kv.NewEtcdKV(cli, kvRoot)
|
||||
kvBase := etcdkv.NewEtcdKV(cli, kvRoot)
|
||||
return kvBase
|
||||
}
|
||||
|
||||
@ -89,8 +89,8 @@ func CreateServer(ctx context.Context) (*Master, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
etcdkv := kv.NewEtcdKV(etcdClient, metaRootPath)
|
||||
metakv, err := NewMetaTable(etcdkv)
|
||||
etcdKV := etcdkv.NewEtcdKV(etcdClient, metaRootPath)
|
||||
metakv, err := NewMetaTable(etcdKV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type metaTable struct {
|
||||
client *kv.EtcdKV // client of a reliable kv service, i.e. etcd client
|
||||
client kv.TxnBase // client of a reliable kv service, i.e. etcd client
|
||||
tenantID2Meta map[UniqueID]pb.TenantMeta // tenant id to tenant meta
|
||||
proxyID2Meta map[UniqueID]pb.ProxyMeta // proxy id to proxy meta
|
||||
collID2Meta map[UniqueID]pb.CollectionMeta // collection id to collection meta
|
||||
@ -23,7 +23,7 @@ type metaTable struct {
|
||||
ddLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMetaTable(kv *kv.EtcdKV) (*metaTable, error) {
|
||||
func NewMetaTable(kv kv.TxnBase) (*metaTable, error) {
|
||||
mt := &metaTable{
|
||||
client: kv,
|
||||
tenantLock: sync.RWMutex{},
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
pb "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
@ -19,7 +19,7 @@ func TestMetaTable_Collection(t *testing.T) {
|
||||
etcdAddr := Params.EtcdAddress
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
|
||||
assert.Nil(t, err)
|
||||
@ -157,7 +157,7 @@ func TestMetaTable_DeletePartition(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
|
||||
assert.Nil(t, err)
|
||||
@ -252,7 +252,7 @@ func TestMetaTable_Segment(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
|
||||
assert.Nil(t, err)
|
||||
@ -333,7 +333,7 @@ func TestMetaTable_UpdateSegment(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
|
||||
assert.Nil(t, err)
|
||||
@ -379,7 +379,7 @@ func TestMetaTable_AddPartition_Limit(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
_, err = cli.Delete(context.TODO(), "/etcd/test/root", clientv3.WithPrefix())
|
||||
assert.Nil(t, err)
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
@ -31,7 +31,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
meta, err := NewMetaTable(etcdKV)
|
||||
assert.Nil(t, err)
|
||||
@ -168,7 +168,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
|
||||
assert.Nil(t, err)
|
||||
etcdKV := kv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
|
||||
|
||||
meta, err := NewMetaTable(etcdKV)
|
||||
assert.Nil(t, err)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
pb "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
@ -29,7 +30,7 @@ var segMgr *SegmentManager
|
||||
var collName = "coll_segmgr_test"
|
||||
var collID = int64(1001)
|
||||
var partitionTag = "test"
|
||||
var kvBase *kv.EtcdKV
|
||||
var kvBase kv.TxnBase
|
||||
var master *Master
|
||||
var masterCancelFunc context.CancelFunc
|
||||
|
||||
@ -48,7 +49,7 @@ func setup() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
kvBase = kv.NewEtcdKV(cli, rootPath)
|
||||
kvBase = etcdkv.NewEtcdKV(cli, rootPath)
|
||||
tmpMt, err := NewMetaTable(kvBase)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -10,11 +10,11 @@ import (
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
)
|
||||
|
||||
type SysConfig struct {
|
||||
kv *kv.EtcdKV
|
||||
kv *etcdkv.EtcdKV
|
||||
}
|
||||
|
||||
// Initialize Configs from config files, and store them in Etcd.
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -31,7 +31,7 @@ func Test_SysConfig(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
|
||||
rootPath := "/test/root"
|
||||
configKV := kv.NewEtcdKV(cli, rootPath)
|
||||
configKV := etcdkv.NewEtcdKV(cli, rootPath)
|
||||
defer configKV.Close()
|
||||
|
||||
sc := SysConfig{kv: configKV}
|
||||
|
@ -47,7 +47,7 @@ type atomicObject struct {
|
||||
// timestampOracle is used to maintain the logic of tso.
|
||||
type timestampOracle struct {
|
||||
key string
|
||||
kvBase kv.Base
|
||||
kvBase kv.TxnBase
|
||||
|
||||
// TODO: remove saveInterval
|
||||
saveInterval time.Duration
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -69,11 +70,22 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) {
|
||||
|
||||
func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) {
|
||||
for i := 0; i < len(channels); i++ {
|
||||
pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
|
||||
if err != nil {
|
||||
log.Printf("Failed to create querynode producer %s, error = %v", channels[i], err)
|
||||
fn := func() error {
|
||||
pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pp == nil {
|
||||
return errors.New("pulsar is not ready, producer is nil")
|
||||
}
|
||||
ms.producers = append(ms.producers, &pp)
|
||||
return nil
|
||||
}
|
||||
err := Retry(10, time.Millisecond*200, fn)
|
||||
if err != nil {
|
||||
errMsg := "Failed to create producer " + channels[i] + ", error = " + err.Error()
|
||||
panic(errMsg)
|
||||
}
|
||||
ms.producers = append(ms.producers, &pp)
|
||||
}
|
||||
}
|
||||
|
||||
@ -83,18 +95,29 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string,
|
||||
pulsarBufSize int64) {
|
||||
ms.unmarshal = unmarshal
|
||||
for i := 0; i < len(channels); i++ {
|
||||
receiveChannel := make(chan pulsar.ConsumerMessage, pulsarBufSize)
|
||||
pc, err := (*ms.client).Subscribe(pulsar.ConsumerOptions{
|
||||
Topic: channels[i],
|
||||
SubscriptionName: subName,
|
||||
Type: pulsar.KeyShared,
|
||||
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
|
||||
MessageChannel: receiveChannel,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to subscribe topic, error = %v", err)
|
||||
fn := func() error {
|
||||
receiveChannel := make(chan pulsar.ConsumerMessage, pulsarBufSize)
|
||||
pc, err := (*ms.client).Subscribe(pulsar.ConsumerOptions{
|
||||
Topic: channels[i],
|
||||
SubscriptionName: subName,
|
||||
Type: pulsar.KeyShared,
|
||||
SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest,
|
||||
MessageChannel: receiveChannel,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pc == nil {
|
||||
return errors.New("pulsar is not ready, consumer is nil")
|
||||
}
|
||||
ms.consumers = append(ms.consumers, &pc)
|
||||
return nil
|
||||
}
|
||||
err := Retry(10, time.Millisecond*200, fn)
|
||||
if err != nil {
|
||||
errMsg := "Failed to create consumer " + channels[i] + ", error = " + err.Error()
|
||||
panic(errMsg)
|
||||
}
|
||||
ms.consumers = append(ms.consumers, &pc)
|
||||
}
|
||||
}
|
||||
|
||||
|
32
internal/msgstream/retry.go
Normal file
@ -0,0 +1,32 @@
package msgstream

import (
	"log"
	"time"
)

// Reference: https://blog.cyeam.com/golang/2018/08/27/retry

func Retry(attempts int, sleep time.Duration, fn func() error) error {
	if err := fn(); err != nil {
		if s, ok := err.(InterruptError); ok {
			return s.error
		}

		if attempts--; attempts > 0 {
			log.Printf("retry func error: %s. attempts #%d after %s.", err.Error(), attempts, sleep)
			time.Sleep(sleep)
			return Retry(attempts, 2*sleep, fn)
		}
		return err
	}
	return nil
}

type InterruptError struct {
	error
}

func NoRetryError(err error) InterruptError {
	return InterruptError{err}
}
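
For context, a minimal sketch of how this new Retry helper might be exercised from inside the msgstream package; the test function and the flaky closure are illustrative, not part of this commit:

// example_retry_test.go — illustrative only; assumes it sits in package msgstream,
// where Retry and NoRetryError above are defined.
package msgstream

import (
	"errors"
	"testing"
	"time"
)

func TestRetryExample(t *testing.T) {
	failures := 0
	flaky := func() error {
		if failures < 2 {
			failures++
			return errors.New("pulsar is not ready") // transient: retried after a doubled sleep
		}
		return nil
	}
	// Up to 10 attempts, starting at 200ms and doubling each time —
	// the same parameters CreatePulsarProducers/CreatePulsarConsumers pass above.
	if err := Retry(10, time.Millisecond*200, flaky); err != nil {
		t.Fatal(err)
	}
	// Wrapping an error with NoRetryError makes Retry return it immediately.
	if err := Retry(3, time.Second, func() error {
		return NoRetryError(errors.New("bad configuration, do not retry"))
	}); err == nil {
		t.Fatal("expected the non-retryable error to surface")
	}
}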
|
@ -18,6 +18,7 @@ const (
|
||||
)
|
||||
|
||||
func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) {
|
||||
log.Println("insert into: ", in.CollectionName)
|
||||
it := &InsertTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
@ -76,6 +77,7 @@ func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.
|
||||
}
|
||||
|
||||
func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSchema) (*commonpb.Status, error) {
|
||||
log.Println("create collection: ", req)
|
||||
cct := &CreateCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: internalpb.CreateCollectionRequest{
|
||||
@ -117,6 +119,7 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc
|
||||
}
|
||||
|
||||
func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.QueryResult, error) {
|
||||
log.Println("search: ", req.CollectionName, req.Dsl)
|
||||
qt := &QueryTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
SearchRequest: internalpb.SearchRequest{
|
||||
@ -164,6 +167,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu
|
||||
}
|
||||
|
||||
func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionName) (*commonpb.Status, error) {
|
||||
log.Println("drop collection: ", req)
|
||||
dct := &DropCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropCollectionRequest: internalpb.DropCollectionRequest{
|
||||
@ -204,6 +208,7 @@ func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionNam
|
||||
}
|
||||
|
||||
func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.BoolResponse, error) {
|
||||
log.Println("has collection: ", req)
|
||||
hct := &HasCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
HasCollectionRequest: internalpb.HasCollectionRequest{
|
||||
@ -248,6 +253,7 @@ func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName
|
||||
}
|
||||
|
||||
func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.CollectionDescription, error) {
|
||||
log.Println("describe collection: ", req)
|
||||
dct := &DescribeCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DescribeCollectionRequest: internalpb.DescribeCollectionRequest{
|
||||
@ -292,6 +298,7 @@ func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.Collectio
|
||||
}
|
||||
|
||||
func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*servicepb.StringListResponse, error) {
|
||||
log.Println("show collections")
|
||||
sct := &ShowCollectionsTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ShowCollectionRequest: internalpb.ShowCollectionRequest{
|
||||
@ -335,6 +342,7 @@ func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*serv
|
||||
}
|
||||
|
||||
func (p *Proxy) CreatePartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
|
||||
log.Println("create partition", in)
|
||||
cpt := &CreatePartitionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreatePartitionRequest: internalpb.CreatePartitionRequest{
|
||||
@ -380,6 +388,7 @@ func (p *Proxy) CreatePartition(ctx context.Context, in *servicepb.PartitionName
|
||||
}
|
||||
|
||||
func (p *Proxy) DropPartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
|
||||
log.Println("drop partition: ", in)
|
||||
dpt := &DropPartitionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropPartitionRequest: internalpb.DropPartitionRequest{
|
||||
@ -426,6 +435,7 @@ func (p *Proxy) DropPartition(ctx context.Context, in *servicepb.PartitionName)
|
||||
}
|
||||
|
||||
func (p *Proxy) HasPartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.BoolResponse, error) {
|
||||
log.Println("has partition: ", in)
|
||||
hpt := &HasPartitionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
HasPartitionRequest: internalpb.HasPartitionRequest{
|
||||
@ -478,6 +488,7 @@ func (p *Proxy) HasPartition(ctx context.Context, in *servicepb.PartitionName) (
|
||||
}
|
||||
|
||||
func (p *Proxy) DescribePartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.PartitionDescription, error) {
|
||||
log.Println("describe partition: ", in)
|
||||
dpt := &DescribePartitionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DescribePartitionRequest: internalpb.DescribePartitionRequest{
|
||||
@ -532,6 +543,7 @@ func (p *Proxy) DescribePartition(ctx context.Context, in *servicepb.PartitionNa
|
||||
}
|
||||
|
||||
func (p *Proxy) ShowPartitions(ctx context.Context, req *servicepb.CollectionName) (*servicepb.StringListResponse, error) {
|
||||
log.Println("show partitions: ", req)
|
||||
spt := &ShowPartitionsTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ShowPartitionRequest: internalpb.ShowPartitionRequest{
|
||||
|
@ -4,30 +4,26 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/allocator"
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Hit(collectionName string) bool
|
||||
Get(collectionName string) (*servicepb.CollectionDescription, error)
|
||||
Update(collectionName string) error
|
||||
Sync(collectionName string) error
|
||||
Update(collectionName string, desc *servicepb.CollectionDescription) error
|
||||
Remove(collectionName string) error
|
||||
}
|
||||
|
||||
var globalMetaCache Cache
|
||||
|
||||
type SimpleMetaCache struct {
|
||||
mu sync.RWMutex
|
||||
proxyID UniqueID
|
||||
metas map[string]*servicepb.CollectionDescription // collection name to schema
|
||||
masterClient masterpb.MasterClient
|
||||
reqIDAllocator *allocator.IDAllocator
|
||||
tsoAllocator *allocator.TimestampAllocator
|
||||
ctx context.Context
|
||||
mu sync.RWMutex
|
||||
metas map[string]*servicepb.CollectionDescription // collection name to schema
|
||||
ctx context.Context
|
||||
proxyInstance *Proxy
|
||||
}
|
||||
|
||||
func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
|
||||
@ -47,58 +43,34 @@ func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.Collect
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
func (metaCache *SimpleMetaCache) Update(collectionName string) error {
|
||||
reqID, err := metaCache.reqIDAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts, err := metaCache.tsoAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasCollectionReq := &internalpb.HasCollectionRequest{
|
||||
MsgType: internalpb.MsgType_kHasCollection,
|
||||
ReqID: reqID,
|
||||
Timestamp: ts,
|
||||
ProxyID: metaCache.proxyID,
|
||||
CollectionName: &servicepb.CollectionName{
|
||||
CollectionName: collectionName,
|
||||
func (metaCache *SimpleMetaCache) Sync(collectionName string) error {
|
||||
dct := &DescribeCollectionTask{
|
||||
Condition: NewTaskCondition(metaCache.ctx),
|
||||
DescribeCollectionRequest: internalpb.DescribeCollectionRequest{
|
||||
MsgType: internalpb.MsgType_kDescribeCollection,
|
||||
CollectionName: &servicepb.CollectionName{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
},
|
||||
masterClient: metaCache.proxyInstance.masterClient,
|
||||
}
|
||||
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has.Value {
|
||||
return errors.New("collection " + collectionName + " not exists")
|
||||
}
|
||||
var cancel func()
|
||||
dct.ctx, cancel = context.WithTimeout(metaCache.ctx, reqTimeoutInterval)
|
||||
defer cancel()
|
||||
|
||||
reqID, err = metaCache.reqIDAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts, err = metaCache.tsoAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req := &internalpb.DescribeCollectionRequest{
|
||||
MsgType: internalpb.MsgType_kDescribeCollection,
|
||||
ReqID: reqID,
|
||||
Timestamp: ts,
|
||||
ProxyID: metaCache.proxyID,
|
||||
CollectionName: &servicepb.CollectionName{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
}
|
||||
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
|
||||
err := metaCache.proxyInstance.sched.DdQueue.Enqueue(dct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dct.WaitToFinish()
|
||||
}
|
||||
|
||||
func (metaCache *SimpleMetaCache) Update(collectionName string, desc *servicepb.CollectionDescription) error {
|
||||
metaCache.mu.Lock()
|
||||
defer metaCache.mu.Unlock()
|
||||
metaCache.metas[collectionName] = resp
|
||||
|
||||
metaCache.metas[collectionName] = desc
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -115,23 +87,14 @@ func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSimpleMetaCache(ctx context.Context,
|
||||
mCli masterpb.MasterClient,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
tsoAllocator *allocator.TimestampAllocator) *SimpleMetaCache {
|
||||
func newSimpleMetaCache(ctx context.Context, proxyInstance *Proxy) *SimpleMetaCache {
|
||||
return &SimpleMetaCache{
|
||||
metas: make(map[string]*servicepb.CollectionDescription),
|
||||
masterClient: mCli,
|
||||
reqIDAllocator: idAllocator,
|
||||
tsoAllocator: tsoAllocator,
|
||||
proxyID: Params.ProxyID(),
|
||||
ctx: ctx,
|
||||
metas: make(map[string]*servicepb.CollectionDescription),
|
||||
proxyInstance: proxyInstance,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func initGlobalMetaCache(ctx context.Context,
|
||||
mCli masterpb.MasterClient,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
tsoAllocator *allocator.TimestampAllocator) {
|
||||
globalMetaCache = newSimpleMetaCache(ctx, mCli, idAllocator, tsoAllocator)
|
||||
func initGlobalMetaCache(ctx context.Context, proxyInstance *Proxy) {
|
||||
globalMetaCache = newSimpleMetaCache(ctx, proxyInstance)
|
||||
}
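
A short sketch of the Hit/Sync/Get pattern the reworked cache expects from its callers (the same pattern InsertTask and QueryTask use below); the wrapper function is made up for illustration:

// Hypothetical helper inside package proxy; Hit, Sync, Get and globalMetaCache
// come from this change, the wrapper itself is illustrative.
func describeWithCache(collectionName string) (*servicepb.CollectionDescription, error) {
	// On a miss, Sync enqueues a DescribeCollectionTask on the proxy's DdQueue;
	// that task fetches the description from master and hands it to Update.
	if !globalMetaCache.Hit(collectionName) {
		if err := globalMetaCache.Sync(collectionName); err != nil {
			return nil, err
		}
	}
	return globalMetaCache.Get(collectionName)
}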
|
||||
|
@ -163,7 +163,7 @@ func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
|
||||
panic(err)
|
||||
}
|
||||
var ret []int
|
||||
for i := start; i <= end; i++ {
|
||||
for i := start; i < end; i++ {
|
||||
ret = append(ret, i)
|
||||
}
|
||||
return ret
|
||||
|
@ -109,7 +109,7 @@ func (p *Proxy) startProxy() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
initGlobalMetaCache(p.proxyLoopCtx, p.masterClient, p.idAllocator, p.tsoAllocator)
|
||||
initGlobalMetaCache(p.proxyLoopCtx, p)
|
||||
p.manipulationMsgStream.Start()
|
||||
p.queryMsgStream.Start()
|
||||
p.sched.Start()
|
||||
|
@ -119,7 +119,7 @@ func createCollection(t *testing.T, name string) {
|
||||
Name: name,
|
||||
Description: "no description",
|
||||
AutoID: true,
|
||||
Fields: make([]*schemapb.FieldSchema, 1),
|
||||
Fields: make([]*schemapb.FieldSchema, 2),
|
||||
}
|
||||
fieldName := "Field1"
|
||||
req.Fields[0] = &schemapb.FieldSchema{
|
||||
@ -127,6 +127,24 @@ func createCollection(t *testing.T, name string) {
|
||||
Description: "no description",
|
||||
DataType: schemapb.DataType_INT32,
|
||||
}
|
||||
fieldName = "vec"
|
||||
req.Fields[1] = &schemapb.FieldSchema{
|
||||
Name: fieldName,
|
||||
Description: "vector",
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "metric_type",
|
||||
Value: "L2",
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := proxyClient.CreateCollection(ctx, req)
|
||||
assert.Nil(t, err)
|
||||
msg := "Create Collection " + name + " should succeed!"
|
||||
@ -139,7 +157,7 @@ func dropCollection(t *testing.T, name string) {
|
||||
}
|
||||
resp, err := proxyClient.DropCollection(ctx, req)
|
||||
assert.Nil(t, err)
|
||||
msg := "Drop Collection " + name + " should succeed!"
|
||||
msg := "Drop Collection " + name + " should succeed! err :" + resp.Reason
|
||||
assert.Equal(t, resp.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
|
||||
}
|
||||
|
||||
@ -152,6 +170,7 @@ func TestProxy_CreateCollection(t *testing.T) {
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
createCollection(t, collectionName)
|
||||
dropCollection(t, collectionName)
|
||||
}(&wg)
|
||||
}
|
||||
wg.Wait()
|
||||
@ -165,9 +184,11 @@ func TestProxy_HasCollection(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
createCollection(t, collectionName)
|
||||
has := hasCollection(t, collectionName)
|
||||
msg := "Should has Collection " + collectionName
|
||||
assert.Equal(t, has, true, msg)
|
||||
dropCollection(t, collectionName)
|
||||
}(&wg)
|
||||
}
|
||||
wg.Wait()
|
||||
@ -182,6 +203,7 @@ func TestProxy_DescribeCollection(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
createCollection(t, collectionName)
|
||||
has := hasCollection(t, collectionName)
|
||||
if has {
|
||||
resp, err := proxyClient.DescribeCollection(ctx, &servicepb.CollectionName{CollectionName: collectionName})
|
||||
@ -191,6 +213,7 @@ func TestProxy_DescribeCollection(t *testing.T) {
|
||||
msg := "Describe Collection " + strconv.Itoa(i) + " should succeed!"
|
||||
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
|
||||
t.Logf("Describe Collection %v: %v", i, resp)
|
||||
dropCollection(t, collectionName)
|
||||
}
|
||||
}(&wg)
|
||||
}
|
||||
@ -206,6 +229,7 @@ func TestProxy_ShowCollections(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
createCollection(t, collectionName)
|
||||
has := hasCollection(t, collectionName)
|
||||
if has {
|
||||
resp, err := proxyClient.ShowCollections(ctx, &commonpb.Empty{})
|
||||
@ -215,6 +239,7 @@ func TestProxy_ShowCollections(t *testing.T) {
|
||||
msg := "Show collections " + strconv.Itoa(i) + " should succeed!"
|
||||
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
|
||||
t.Logf("Show collections %v: %v", i, resp)
|
||||
dropCollection(t, collectionName)
|
||||
}
|
||||
}(&wg)
|
||||
}
|
||||
@ -246,6 +271,7 @@ func TestProxy_Insert(t *testing.T) {
|
||||
}
|
||||
msg := "Insert into Collection " + strconv.Itoa(i) + " should succeed!"
|
||||
assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS, msg)
|
||||
dropCollection(t, collectionName)
|
||||
}
|
||||
}(&wg)
|
||||
}
|
||||
@ -308,6 +334,7 @@ func TestProxy_Search(t *testing.T) {
|
||||
queryWg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
//createCollection(t, collectionName)
|
||||
has := hasCollection(t, collectionName)
|
||||
if !has {
|
||||
createCollection(t, collectionName)
|
||||
@ -315,6 +342,7 @@ func TestProxy_Search(t *testing.T) {
|
||||
resp, err := proxyClient.Search(ctx, req)
|
||||
t.Logf("response of search collection %v: %v", i, resp)
|
||||
assert.Nil(t, err)
|
||||
dropCollection(t, collectionName)
|
||||
}(&queryWg)
|
||||
}
|
||||
|
||||
@ -328,9 +356,9 @@ func TestProxy_Search(t *testing.T) {
|
||||
func TestProxy_AssignSegID(t *testing.T) {
|
||||
collectionName := "CreateCollection1"
|
||||
createCollection(t, collectionName)
|
||||
testNum := 4
|
||||
testNum := 1
|
||||
for i := 0; i < testNum; i++ {
|
||||
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, "default", int32(i), 200000)
|
||||
segID, err := proxyServer.segAssigner.GetSegmentID(collectionName, Params.defaultPartitionTag(), int32(i), 200000)
|
||||
assert.Nil(t, err)
|
||||
fmt.Println("segID", segID)
|
||||
}
|
||||
@ -345,6 +373,7 @@ func TestProxy_DropCollection(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
createCollection(t, collectionName)
|
||||
has := hasCollection(t, collectionName)
|
||||
if has {
|
||||
dropCollection(t, collectionName)
|
||||
@ -357,27 +386,14 @@ func TestProxy_DropCollection(t *testing.T) {
|
||||
func TestProxy_PartitionGRPC(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
collName := "collPartTest"
|
||||
filedName := "collPartTestF1"
|
||||
collReq := &schemapb.CollectionSchema{
|
||||
Name: collName,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&schemapb.FieldSchema{
|
||||
Name: filedName,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
},
|
||||
},
|
||||
}
|
||||
st, err := proxyClient.CreateCollection(ctx, collReq)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)
|
||||
createCollection(t, collName)
|
||||
|
||||
for i := 0; i < testNum; i++ {
|
||||
wg.Add(1)
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tag := fmt.Sprintf("partition-%d", i)
|
||||
tag := fmt.Sprintf("partition_%d", i)
|
||||
preq := &servicepb.PartitionName{
|
||||
CollectionName: collName,
|
||||
Tag: tag,
|
||||
@ -413,6 +429,7 @@ func TestProxy_PartitionGRPC(t *testing.T) {
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
dropCollection(t, collName)
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -91,7 +91,7 @@ func (it *InsertTask) PreExecute() error {
|
||||
func (it *InsertTask) Execute() error {
|
||||
collectionName := it.BaseInsertTask.CollectionName
|
||||
if !globalMetaCache.Hit(collectionName) {
|
||||
err := globalMetaCache.Update(collectionName)
|
||||
err := globalMetaCache.Sync(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -103,17 +103,20 @@ func (it *InsertTask) Execute() error {
|
||||
autoID := description.Schema.AutoID
|
||||
var rowIDBegin UniqueID
|
||||
var rowIDEnd UniqueID
|
||||
if autoID || true {
|
||||
rowNums := len(it.BaseInsertTask.RowData)
|
||||
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
|
||||
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.BaseInsertTask.RowIDs[offset] = i
|
||||
}
|
||||
|
||||
if autoID {
|
||||
if it.HashValues == nil || len(it.HashValues) == 0 {
|
||||
it.HashValues = make([]uint32, 0)
|
||||
}
|
||||
rowNums := len(it.BaseInsertTask.RowData)
|
||||
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
|
||||
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.BaseInsertTask.RowIDs[offset] = i
|
||||
hashValue, _ := typeutil.Hash32Int64(i)
|
||||
for _, rowID := range it.RowIDs {
|
||||
hashValue, _ := typeutil.Hash32Int64(rowID)
|
||||
it.HashValues = append(it.HashValues, hashValue)
|
||||
}
|
||||
}
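
As a small illustration of the change above: with AutoID set, the proxy now rebuilds HashValues from the freshly allocated RowIDs, so the two slices stay aligned with the row data. A hedged sketch, assuming package proxy with typeutil imported; the function is hypothetical:

// exampleRowIDHashing is illustrative only: hash each allocated row ID in [begin, end),
// mirroring what InsertTask.Execute now does with it.RowIDs.
func exampleRowIDHashing(begin, end UniqueID) ([]UniqueID, []uint32) {
	rowIDs := make([]UniqueID, 0, end-begin)
	hashValues := make([]uint32, 0, end-begin)
	for id := begin; id < end; id++ {
		rowIDs = append(rowIDs, id)
		h, _ := typeutil.Hash32Int64(id) // same helper the task code uses
		hashValues = append(hashValues, h)
	}
	return rowIDs, hashValues
}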
|
||||
@ -126,6 +129,7 @@ func (it *InsertTask) Execute() error {
|
||||
}
|
||||
msgPack.Msgs[0] = tsMsg
|
||||
err = it.manipulationMsgStream.Produce(msgPack)
|
||||
|
||||
it.result = &servicepb.IntegerRangeResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_SUCCESS,
|
||||
@ -352,7 +356,7 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
|
||||
func (qt *QueryTask) PreExecute() error {
|
||||
collectionName := qt.query.CollectionName
|
||||
if !globalMetaCache.Hit(collectionName) {
|
||||
err := globalMetaCache.Update(collectionName)
|
||||
err := globalMetaCache.Sync(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -605,14 +609,9 @@ func (dct *DescribeCollectionTask) PreExecute() error {
|
||||
}
|
||||
|
||||
func (dct *DescribeCollectionTask) Execute() error {
|
||||
if !globalMetaCache.Hit(dct.CollectionName.CollectionName) {
|
||||
err := globalMetaCache.Update(dct.CollectionName.CollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var err error
|
||||
dct.result, err = globalMetaCache.Get(dct.CollectionName.CollectionName)
|
||||
dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest)
|
||||
globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ func TestTimeTick_Start(t *testing.T) {
|
||||
|
||||
func TestTimeTick_Start2(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
masterAddr := Params.MasterAddress()
|
||||
tsoAllocator, err := allocator.NewTimestampAllocator(ctx, masterAddr)
|
||||
assert.Nil(t, err)
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
)
|
||||
|
||||
func TestValidateCollectionName(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidateCollectionName("abc"))
|
||||
assert.Nil(t, ValidateCollectionName("_123abc"))
|
||||
assert.Nil(t, ValidateCollectionName("abc123_$"))
|
||||
@ -34,8 +33,8 @@ func TestValidateCollectionName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidatePartitionTag(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidatePartitionTag("abc", true))
|
||||
assert.Nil(t, ValidatePartitionTag("123abc", true))
|
||||
assert.Nil(t, ValidatePartitionTag("_123abc", true))
|
||||
assert.Nil(t, ValidatePartitionTag("abc123_$", true))
|
||||
|
||||
@ -44,7 +43,6 @@ func TestValidatePartitionTag(t *testing.T) {
|
||||
longName[i] = 'a'
|
||||
}
|
||||
invalidNames := []string{
|
||||
"123abc",
|
||||
"$abc",
|
||||
"_12 ac",
|
||||
" ",
|
||||
@ -62,7 +60,6 @@ func TestValidatePartitionTag(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateFieldName(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidateFieldName("abc"))
|
||||
assert.Nil(t, ValidateFieldName("_123abc"))
|
||||
|
||||
@ -86,7 +83,6 @@ func TestValidateFieldName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateDimension(t *testing.T) {
|
||||
Params.Init()
|
||||
assert.Nil(t, ValidateDimension(1, false))
|
||||
assert.Nil(t, ValidateDimension(Params.MaxDimension(), false))
|
||||
assert.Nil(t, ValidateDimension(8, true))
|
||||
|
@ -12,6 +12,7 @@ package querynode
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"log"
|
||||
"strconv"
|
||||
@ -229,6 +230,7 @@ func (colReplica *collectionReplicaImpl) addPartitionsByCollectionMeta(colMeta *
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
fmt.Println("add partition: ", tag)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -262,6 +264,7 @@ func (colReplica *collectionReplicaImpl) removePartitionsByCollectionMeta(colMet
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
fmt.Println("delete partition: ", tag)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
"go.etcd.io/etcd/mvcc/mvccpb"
|
||||
@ -25,7 +25,7 @@ const (
|
||||
|
||||
type metaService struct {
|
||||
ctx context.Context
|
||||
kvBase *kv.EtcdKV
|
||||
kvBase *etcdkv.EtcdKV
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ func newMetaService(ctx context.Context, replica *collectionReplica) *metaServic
|
||||
|
||||
return &metaService{
|
||||
ctx: ctx,
|
||||
kvBase: kv.NewEtcdKV(cli, MetaRootPath),
|
||||
kvBase: etcdkv.NewEtcdKV(cli, MetaRootPath),
|
||||
replica: replica,
|
||||
}
|
||||
}
|
||||
@ -145,7 +145,7 @@ func printSegmentStruct(obj *etcdpb.SegmentMeta) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processCollectionCreate(id string, value string) {
|
||||
println(fmt.Sprintf("Create Collection:$%s$", id))
|
||||
//println(fmt.Sprintf("Create Collection:$%s$", id))
|
||||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
@ -163,7 +163,7 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processSegmentCreate(id string, value string) {
|
||||
println("Create Segment: ", id)
|
||||
//println("Create Segment: ", id)
|
||||
|
||||
seg := mService.segmentUnmarshal(value)
|
||||
if !isSegmentChannelRangeInQueryNodeChannelRange(seg) {
|
||||
@ -182,7 +182,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processCreate(key string, msg string) {
|
||||
println("process create", key)
|
||||
//println("process create", key)
|
||||
if isCollectionObj(key) {
|
||||
objID := GetCollectionObjID(key)
|
||||
mService.processCollectionCreate(objID, msg)
|
||||
@ -214,7 +214,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processCollectionModify(id string, value string) {
|
||||
println("Modify Collection: ", id)
|
||||
//println("Modify Collection: ", id)
|
||||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
@ -242,7 +242,7 @@ func (mService *metaService) processModify(key string, msg string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processSegmentDelete(id string) {
|
||||
println("Delete segment: ", id)
|
||||
//println("Delete segment: ", id)
|
||||
|
||||
var segmentID, err = strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
@ -257,7 +257,7 @@ func (mService *metaService) processSegmentDelete(id string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processCollectionDelete(id string) {
|
||||
println("Delete collection: ", id)
|
||||
//println("Delete collection: ", id)
|
||||
|
||||
var collectionID, err = strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
@ -272,7 +272,7 @@ func (mService *metaService) processCollectionDelete(id string) {
|
||||
}
|
||||
|
||||
func (mService *metaService) processDelete(key string) {
|
||||
println("process delete")
|
||||
//println("process delete")
|
||||
|
||||
if isCollectionObj(key) {
|
||||
objID := GetCollectionObjID(key)
|
||||
|
@ -29,7 +29,7 @@ func createPlan(col Collection, dsl string) (*Plan, error) {
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
return nil, errors.New("Create plan failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
|
||||
var newPlan = &Plan{cPlan: cPlan}
|
||||
@ -60,7 +60,7 @@ func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGro
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
return nil, errors.New("Parser placeholder group failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
|
||||
var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}
|
||||
|
@ -10,6 +10,8 @@ package querynode
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@ -21,26 +23,66 @@ type MarshaledHits struct {
|
||||
cMarshaledHits C.CMarshaledHits
|
||||
}
|
||||
|
||||
func reduceSearchResults(searchResults []*SearchResult, numSegments int64) *SearchResult {
|
||||
func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
|
||||
cSearchResults := make([]C.CQueryResult, 0)
|
||||
for _, res := range searchResults {
|
||||
cSearchResults = append(cSearchResults, res.cQueryResult)
|
||||
}
|
||||
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
|
||||
cNumSegments := C.long(numSegments)
|
||||
res := C.ReduceQueryResults(cSearchResultPtr, cNumSegments)
|
||||
return &SearchResult{cQueryResult: res}
|
||||
cInReduced := (*C.bool)(&inReduced[0])
|
||||
|
||||
status := C.ReduceQueryResults(cSearchResultPtr, cNumSegments, cInReduced)
|
||||
|
||||
errorCode := status.error_code
|
||||
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New("reduceSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sr *SearchResult) reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup) *MarshaledHits {
|
||||
func fillTargetEntry(plan *Plan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
|
||||
for i, value := range inReduced {
|
||||
if value {
|
||||
err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func reorganizeQueryResults(plan *Plan, placeholderGroups []*PlaceholderGroup, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
|
||||
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
|
||||
for _, pg := range placeholderGroups {
|
||||
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
|
||||
}
|
||||
cNumGroup := (C.long)(len(placeholderGroups))
|
||||
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
|
||||
res := C.ReorganizeQueryResults(sr.cQueryResult, plan.cPlan, cPlaceHolder, cNumGroup)
|
||||
return &MarshaledHits{cMarshaledHits: res}
|
||||
var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
|
||||
var cNumGroup = (C.long)(len(placeholderGroups))
|
||||
|
||||
cSearchResults := make([]C.CQueryResult, 0)
|
||||
for _, res := range searchResults {
|
||||
cSearchResults = append(cSearchResults, res.cQueryResult)
|
||||
}
|
||||
cSearchResultPtr := (*C.CQueryResult)(&cSearchResults[0])
|
||||
|
||||
var cNumSegments = C.long(numSegments)
|
||||
var cInReduced = (*C.bool)(&inReduced[0])
|
||||
var cMarshaledHits C.CMarshaledHits
|
||||
|
||||
status := C.ReorganizeQueryResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cPlan)
|
||||
errorCode := status.error_code
|
||||
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return nil, errors.New("reorganizeQueryResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
|
||||
}
|
||||
|
||||
func (mh *MarshaledHits) getHitsBlobSize() int64 {
|
||||
|
@ -107,15 +107,21 @@ func TestReduce_AllFunc(t *testing.T) {
|
||||
placeholderGroups = append(placeholderGroups, holder)
|
||||
|
||||
searchResults := make([]*SearchResult, 0)
|
||||
matchedSegment := make([]*Segment, 0)
|
||||
searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{0})
|
||||
assert.Nil(t, err)
|
||||
searchResults = append(searchResults, searchResult)
|
||||
matchedSegment = append(matchedSegment, segment)
|
||||
|
||||
reducedSearchResults := reduceSearchResults(searchResults, 1)
|
||||
assert.NotNil(t, reducedSearchResults)
|
||||
testReduce := make([]bool, len(searchResults))
|
||||
err = reduceSearchResults(searchResults, 1, testReduce)
|
||||
assert.Nil(t, err)
|
||||
err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
|
||||
assert.Nil(t, err)
|
||||
|
||||
marshaledHits := reducedSearchResults.reorganizeQueryResults(plan, placeholderGroups)
|
||||
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, 1, testReduce)
|
||||
assert.NotNil(t, marshaledHits)
|
||||
assert.Nil(t, err)
|
||||
|
||||
hitsBlob, err := marshaledHits.getHitsBlob()
|
||||
assert.Nil(t, err)
|
||||
@ -137,7 +143,6 @@ func TestReduce_AllFunc(t *testing.T) {
|
||||
plan.delete()
|
||||
holder.delete()
|
||||
deleteSearchResults(searchResults)
|
||||
deleteSearchResults([]*SearchResult{reducedSearchResults})
|
||||
deleteMarshaledHits(marshaledHits)
|
||||
deleteSegment(segment)
|
||||
deleteCollection(collection)
|
||||
|
@ -139,7 +139,7 @@ func (ss *searchService) receiveSearchMsg() {
|
||||
err := ss.search(msg)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
err = ss.publishFailedSearchResult(msg)
|
||||
err = ss.publishFailedSearchResult(msg, err.Error())
|
||||
if err != nil {
|
||||
log.Println("publish FailedSearchResult failed, error message: ", err)
|
||||
}
|
||||
@ -191,7 +191,7 @@ func (ss *searchService) doUnsolvedMsgSearch() {
|
||||
err := ss.search(msg)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
err = ss.publishFailedSearchResult(msg)
|
||||
err = ss.publishFailedSearchResult(msg, err.Error())
|
||||
if err != nil {
|
||||
log.Println("publish FailedSearchResult failed, error message: ", err)
|
||||
}
|
||||
@ -238,6 +238,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
||||
placeholderGroups = append(placeholderGroups, placeholderGroup)
|
||||
|
||||
searchResults := make([]*SearchResult, 0)
|
||||
matchedSegments := make([]*Segment, 0)
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
|
||||
@ -257,6 +258,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
||||
return err
|
||||
}
|
||||
searchResults = append(searchResults, searchResult)
|
||||
matchedSegments = append(matchedSegments, segment)
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,8 +284,20 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults)))
|
||||
marshaledHits := reducedSearchResult.reorganizeQueryResults(plan, placeholderGroups)
|
||||
inReduced := make([]bool, len(searchResults))
|
||||
numSegment := int64(len(searchResults))
|
||||
err = reduceSearchResults(searchResults, numSegment, inReduced)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hitsBlob, err := marshaledHits.getHitsBlob()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -291,12 +305,12 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
||||
|
||||
var offset int64 = 0
|
||||
for index := range placeholderGroups {
|
||||
hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
|
||||
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hits := make([][]byte, 0)
|
||||
for _, len := range hitBolbSizePeerQuery {
|
||||
for _, len := range hitBlobSizePeerQuery {
|
||||
hits = append(hits, hitsBlob[offset:offset+len])
|
||||
//test code to checkout marshaled hits
|
||||
//marshaledHit := hitsBlob[offset:offset+len]
|
||||
@ -329,7 +343,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
||||
}
|
||||
|
||||
deleteSearchResults(searchResults)
|
||||
deleteSearchResults([]*SearchResult{reducedSearchResult})
|
||||
deleteMarshaledHits(marshaledHits)
|
||||
plan.delete()
|
||||
placeholderGroup.delete()
|
||||
@ -346,7 +359,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error {
|
||||
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
|
||||
msgPack := msgstream.MsgPack{}
|
||||
searchMsg, ok := msg.(*msgstream.SearchMsg)
|
||||
if !ok {
|
||||
@ -354,7 +367,7 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error {
|
||||
}
|
||||
var results = internalpb.SearchResult{
|
||||
MsgType: internalpb.MsgType_kSearchResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: errMsg},
|
||||
ReqID: searchMsg.ReqID,
|
||||
ProxyID: searchMsg.ProxyID,
|
||||
QueryNodeID: searchMsg.ProxyID,
|
||||
|
@ -253,3 +253,242 @@ func TestSearch_Search(t *testing.T) {
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestSearch_SearchMultiSegments(t *testing.T) {
|
||||
Params.Init()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 1024
|
||||
const receiveBufSize = 1024
|
||||
const DIM = 16
|
||||
insertProducerChannels := Params.insertChannelNames()
|
||||
searchProducerChannels := Params.searchChannelNames()
|
||||
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
|
||||
// start search service
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
var searchRawData1 []byte
|
||||
var searchRawData2 []byte
|
||||
for i, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
|
||||
searchRawData1 = append(searchRawData1, buf...)
|
||||
}
|
||||
for i, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
|
||||
searchRawData2 = append(searchRawData2, buf...)
|
||||
}
|
||||
placeholderValue := servicepb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
|
||||
Values: [][]byte{searchRawData1, searchRawData2},
|
||||
}
|
||||
|
||||
placeholderGroup := servicepb.PlaceholderGroup{
|
||||
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
|
||||
}
|
||||
|
||||
placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
||||
if err != nil {
|
||||
log.Print("marshal placeholderGroup failed")
|
||||
}
|
||||
|
||||
query := servicepb.Query{
|
||||
CollectionName: "collection0",
|
||||
PartitionTags: []string{"default"},
|
||||
Dsl: dslString,
|
||||
PlaceholderGroup: placeGroupByte,
|
||||
}
|
||||
|
||||
queryByte, err := proto.Marshal(&query)
|
||||
if err != nil {
|
||||
log.Print("marshal query failed")
|
||||
}
|
||||
|
||||
blob := commonpb.Blob{
|
||||
Value: queryByte,
|
||||
}
|
||||
|
||||
searchMsg := &msgstream.SearchMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
SearchRequest: internalpb.SearchRequest{
|
||||
MsgType: internalpb.MsgType_kSearch,
|
||||
ReqID: int64(1),
|
||||
ProxyID: int64(1),
|
||||
Timestamp: uint64(10 + 1000),
|
||||
ResultChannelID: int64(0),
|
||||
Query: &blob,
|
||||
},
|
||||
}
|
||||
|
||||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
searchStream.Start()
|
||||
err = searchStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
go node.searchService.start()
|
||||
|
||||
// start insert
|
||||
timeRange := TimeRange{
|
||||
timestampMin: 0,
|
||||
timestampMax: math.MaxUint64,
|
||||
}
|
||||
|
||||
insertMessages := make([]msgstream.TsMsg, 0)
|
||||
for i := 0; i < msgLength; i++ {
|
||||
segmentID := 0
|
||||
if i >= msgLength/2 {
|
||||
segmentID = 1
|
||||
}
|
||||
var rawData []byte
|
||||
for _, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
|
||||
rawData = append(rawData, buf...)
|
||||
}
|
||||
bs := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bs, 1)
|
||||
rawData = append(rawData, bs...)
|
||||
|
||||
var msg msgstream.TsMsg = &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{
|
||||
uint32(i),
|
||||
},
|
||||
},
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
MsgType: internalpb.MsgType_kInsert,
|
||||
ReqID: int64(i),
|
||||
CollectionName: "collection0",
|
||||
PartitionTag: "default",
|
||||
SegmentID: int64(segmentID),
|
||||
ChannelID: int64(0),
|
||||
ProxyID: int64(0),
|
||||
Timestamps: []uint64{uint64(i + 1000)},
|
||||
RowIDs: []int64{int64(i)},
|
||||
RowData: []*commonpb.Blob{
|
||||
{Value: rawData},
|
||||
},
|
||||
},
|
||||
}
|
||||
insertMessages = append(insertMessages, msg)
|
||||
}
|
||||
|
||||
msgPack := msgstream.MsgPack{
|
||||
BeginTs: timeRange.timestampMin,
|
||||
EndTs: timeRange.timestampMax,
|
||||
Msgs: insertMessages,
|
||||
}
|
||||
|
||||
// generate timeTick
|
||||
timeTickMsgPack := msgstream.MsgPack{}
|
||||
baseMsg := msgstream.BaseMsg{
|
||||
BeginTimestamp: 0,
|
||||
EndTimestamp: 0,
|
||||
HashValues: []uint32{0},
|
||||
}
|
||||
timeTickResult := internalpb.TimeTickMsg{
|
||||
MsgType: internalpb.MsgType_kTimeTick,
|
||||
PeerID: UniqueID(0),
|
||||
Timestamp: math.MaxUint64,
|
||||
}
|
||||
timeTickMsg := &msgstream.TimeTickMsg{
|
||||
BaseMsg: baseMsg,
|
||||
TimeTickMsg: timeTickResult,
|
||||
}
|
||||
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
|
||||
|
||||
// pulsar produce
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
insertStream.Start()
|
||||
err = insertStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
err = insertStream.Broadcast(&timeTickMsgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ func (s *Segment) segmentSearch(plan *Plan,
|
||||
var cTimestamp = (*C.ulong)(×tamp[0])
|
||||
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
|
||||
var cNumGroups = C.int(len(placeHolderGroups))
|
||||
cQueryResult := (*C.CQueryResult)(&searchResult.cQueryResult)
|
||||
var cQueryResult = (*C.CQueryResult)(&searchResult.cQueryResult)
|
||||
|
||||
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cQueryResult)
|
||||
errorCode := status.error_code
|
||||
@ -221,3 +221,18 @@ func (s *Segment) segmentSearch(plan *Plan,
|
||||
|
||||
return &searchResult, nil
|
||||
}
|
||||
|
||||
func (s *Segment) fillTargetEntry(plan *Plan,
|
||||
result *SearchResult) error {
|
||||
|
||||
var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult)
|
||||
errorCode := status.error_code
|
||||
|
||||
if errorCode != 0 {
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
)
|
||||
|
||||
type Base interface {
|
||||
@ -33,16 +33,25 @@ type Base interface {
|
||||
}
|
||||
|
||||
type BaseTable struct {
|
||||
params *kv.MemoryKV
|
||||
params *memkv.MemoryKV
|
||||
}
|
||||
|
||||
func (gp *BaseTable) Init() {
|
||||
gp.params = kv.NewMemoryKV()
|
||||
gp.params = memkv.NewMemoryKV()
|
||||
err := gp.LoadYaml("config.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
minioAddress := os.Getenv("MINIO_ADDRESS")
|
||||
if minioAddress == "" {
|
||||
minioAddress = "localhost:9000"
|
||||
}
|
||||
err = gp.Save("_MinioAddress", minioAddress)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
etcdAddress := os.Getenv("ETCD_ADDRESS")
|
||||
if etcdAddress == "" {
|
||||
etcdAddress = "localhost:2379"
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/kv"
|
||||
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
@ -25,10 +25,10 @@ func ParseTS(ts uint64) (time.Time, uint64) {
|
||||
return physicalTime, logical
|
||||
}
|
||||
|
||||
func NewTSOKVBase(etcdAddr []string, tsoRoot, subPath string) *kv.EtcdKV {
|
||||
func NewTSOKVBase(etcdAddr []string, tsoRoot, subPath string) *etcdkv.EtcdKV {
|
||||
client, _ := clientv3.New(clientv3.Config{
|
||||
Endpoints: etcdAddr,
|
||||
DialTimeout: 5 * time.Second,
|
||||
})
|
||||
return kv.NewEtcdKV(client, path.Join(tsoRoot, subPath))
|
||||
return etcdkv.NewEtcdKV(client, path.Join(tsoRoot, subPath))
|
||||
}
|
||||
|
@@ -13,5 +13,5 @@ SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
# ignore MinIO, S3 unittests
MILVUS_DIR="${SCRIPTS_DIR}/../internal/"
echo $MILVUS_DIR
#go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/proxy/..." -failfast
go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast
go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/proxy/..." -failfast
#go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast
0
tests/python/__init__.py
Normal file
235
tests/python/conftest.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import socket
|
||||
import pytest
|
||||
|
||||
from .utils import *
|
||||
|
||||
timeout = 60
|
||||
dimension = 128
|
||||
delete_timeout = 60
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--ip", action="store", default="localhost")
|
||||
parser.addoption("--service", action="store", default="")
|
||||
parser.addoption("--port", action="store", default=19530)
|
||||
parser.addoption("--http-port", action="store", default=19121)
|
||||
parser.addoption("--handler", action="store", default="GRPC")
|
||||
parser.addoption("--tag", action="store", default="all", help="only run tests matching the tag.")
|
||||
parser.addoption('--dry-run', action='store_true', default=False)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# register an additional marker
|
||||
config.addinivalue_line(
|
||||
"markers", "tag(name): mark test to run only matching the tag"
|
||||
)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
tags = list()
|
||||
for marker in item.iter_markers(name="tag"):
|
||||
for tag in marker.args:
|
||||
tags.append(tag)
|
||||
if tags:
|
||||
cmd_tag = item.config.getoption("--tag")
|
||||
if cmd_tag != "all" and cmd_tag not in tags:
|
||||
pytest.skip("test requires tag in {!r}".format(tags))
|
||||
|
||||
|
||||
def pytest_runtestloop(session):
|
||||
if session.config.getoption('--dry-run'):
|
||||
total_passed = 0
|
||||
total_skipped = 0
|
||||
test_file_to_items = {}
|
||||
for item in session.items:
|
||||
file_name, test_class, test_func = item.nodeid.split("::")
|
||||
if test_file_to_items.get(file_name) is not None:
|
||||
test_file_to_items[file_name].append(item)
|
||||
else:
|
||||
test_file_to_items[file_name] = [item]
|
||||
for k, items in test_file_to_items.items():
|
||||
skip_case = []
|
||||
should_pass_but_skipped = []
|
||||
skipped_other_reason = []
|
||||
|
||||
level2_case = []
|
||||
for item in items:
|
||||
if "pytestmark" in item.keywords.keys():
|
||||
markers = item.keywords["pytestmark"]
|
||||
skip_case.extend([item.nodeid for marker in markers if marker.name == 'skip'])
|
||||
should_pass_but_skipped.extend([item.nodeid for marker in markers if marker.name == 'skip' and len(marker.args) > 0 and marker.args[0] == "should pass"])
|
||||
skipped_other_reason.extend([item.nodeid for marker in markers if marker.name == 'skip' and (len(marker.args) < 1 or marker.args[0] != "should pass")])
|
||||
level2_case.extend([item.nodeid for marker in markers if marker.name == 'level' and marker.args[0] == 2])
|
||||
|
||||
print("")
|
||||
print(f"[{k}]:")
|
||||
print(f" Total : {len(items):13}")
|
||||
print(f" Passed : {len(items) - len(skip_case):13}")
|
||||
print(f" Skipped : {len(skip_case):13}")
|
||||
print(f" - should pass: {len(should_pass_but_skipped):4}")
|
||||
print(f" - not supported: {len(skipped_other_reason):4}")
|
||||
print(f" Level2 : {len(level2_case):13}")
|
||||
|
||||
print(f" ---------------------------------------")
|
||||
print(f" should pass but skipped: ")
|
||||
print("")
|
||||
for nodeid in should_pass_but_skipped:
|
||||
name, test_class, test_func = nodeid.split("::")
|
||||
print(f" {name:8}: {test_class}.{test_func}")
|
||||
print("")
|
||||
print(f"===============================================")
|
||||
total_passed += len(items) - len(skip_case)
|
||||
total_skipped += len(skip_case)
|
||||
|
||||
print("Total tests : ", len(session.items))
|
||||
print("Total passed: ", total_passed)
|
||||
print("Total skipped: ", total_skipped)
|
||||
return True
|
||||
|
||||
|
||||
def check_server_connection(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
|
||||
connected = True
|
||||
if ip and (ip not in ['localhost', '127.0.0.1']):
|
||||
try:
|
||||
socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
|
||||
except Exception as e:
|
||||
print("Socket connect failed: %s" % str(e))
|
||||
connected = False
|
||||
return connected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
try:
|
||||
milvus = get_milvus(host=ip, port=port, handler=handler)
|
||||
# reset_build_index_threshold(milvus)
|
||||
except Exception as e:
|
||||
logging.getLogger().error(str(e))
|
||||
pytest.exit("Milvus server cannot be connected, exiting pytest ...")
|
||||
def fin():
|
||||
try:
|
||||
milvus.close()
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger().info(str(e))
|
||||
request.addfinalizer(fin)
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dis_connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
milvus = get_milvus(host=ip, port=port, handler=handler)
|
||||
milvus.close()
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
service_name = request.config.getoption("--service")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
args = {"ip": ip, "port": port, "handler": handler, "service_name": service_name}
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def milvus(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
http_port = request.config.getoption("--http-port")
|
||||
handler = request.config.getoption("--handler")
|
||||
if handler == "HTTP":
|
||||
port = http_port
|
||||
return get_milvus(host=ip, port=port, handler=handler)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
default_fields = gen_default_fields()
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
# customised id
|
||||
@pytest.fixture(scope="function")
|
||||
def id_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_default_fields(auto_id=False)
|
||||
connect.create_collection(collection_name, fields)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def binary_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_binary_default_fields()
|
||||
connect.create_collection(collection_name, fields)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
collection_names = connect.list_collections()
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
||||
|
||||
|
||||
# customised id
|
||||
@pytest.fixture(scope="function")
|
||||
def binary_id_collection(request, connect):
|
||||
ori_collection_name = getattr(request.module, "collection_id", "test")
|
||||
collection_name = gen_unique_str(ori_collection_name)
|
||||
try:
|
||||
fields = gen_binary_default_fields(auto_id=False)
|
||||
connect.create_collection(collection_name, fields)
|
||||
except Exception as e:
|
||||
pytest.exit(str(e))
|
||||
def teardown():
|
||||
if connect.has_collection(collection_name):
|
||||
connect.drop_collection(collection_name, timeout=delete_timeout)
|
||||
request.addfinalizer(teardown)
|
||||
assert connect.has_collection(collection_name)
|
||||
return collection_name
|
22
tests/python/constants.py
Normal file
@@ -0,0 +1,22 @@
from . import utils

default_fields = utils.gen_default_fields()
default_binary_fields = utils.gen_binary_default_fields()

default_entity = utils.gen_entities(1)
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)

default_entity_row = utils.gen_entities_rows(1)
default_raw_binary_vector_row, default_binary_entity_row = utils.gen_binary_entities_rows(1)

default_entities = utils.gen_entities(utils.default_nb)
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)

default_entities_new = utils.gen_entities_new(utils.default_nb)
default_raw_binary_vectors_new, default_binary_entities_new = utils.gen_binary_entities_new(utils.default_nb)

default_entities_rows = utils.gen_entities_rows(utils.default_nb)
default_raw_binary_vectors_rows, default_binary_entities_rows = utils.gen_binary_entities_rows(utils.default_nb)
127
tests/python/factorys.py
Normal file
@@ -0,0 +1,127 @@
# STL imports
import random
import string
import time
import datetime
import struct
import sys
import uuid
from functools import wraps

sys.path.append('..')
# Third party imports
import numpy as np
import faker
from faker.providers import BaseProvider

# local application imports
from milvus.client.types import IndexType, MetricType, DataType

# grpc
from milvus.client.grpc_handler import Prepare as gPrepare
from milvus.grpc_gen import milvus_pb2


def gen_vectors(num, dim):
    return [[random.random() for _ in range(dim)] for _ in range(num)]


def gen_single_vector(dim):
    return [[random.random() for _ in range(dim)]]


def gen_vector(nb, d, seed=np.random.RandomState(1234)):
    xb = seed.rand(nb, d).astype("float32")
    return xb.tolist()


def gen_unique_str(str=None):
    prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
    return prefix if str is None else str + "_" + prefix


def get_current_day():
    return time.strftime('%Y-%m-%d', time.localtime())


def get_last_day(day):
    tmp = datetime.datetime.now() - datetime.timedelta(days=day)
    return tmp.strftime('%Y-%m-%d')


def get_next_day(day):
    tmp = datetime.datetime.now() + datetime.timedelta(days=day)
    return tmp.strftime('%Y-%m-%d')


def gen_long_str(num):
    # build a string of length num from the characters of 'tomorrow'
    long_str = ''
    for _ in range(num):
        char = random.choice('tomorrow')
        long_str += char
    return long_str


def gen_one_binary(topk):
    ids = [random.randrange(10000000, 99999999) for _ in range(topk)]
    distances = [random.random() for _ in range(topk)]
    return milvus_pb2.TopKQueryResult(struct.pack(str(topk) + 'l', *ids), struct.pack(str(topk) + 'd', *distances))


def gen_nq_binaries(nq, topk):
    return [gen_one_binary(topk) for _ in range(nq)]


def fake_query_bin_result(nq, topk):
    return gen_nq_binaries(nq, topk)


class FakerProvider(BaseProvider):

    def collection_name(self):
        return 'collection_names' + str(uuid.uuid4()).replace('-', '_')

    def normal_field_name(self):
        return 'normal_field_names' + str(uuid.uuid4()).replace('-', '_')

    def vector_field_name(self):
        return 'vector_field_names' + str(uuid.uuid4()).replace('-', '_')

    def name(self):
        return 'name' + str(random.randint(1000, 9999))

    def dim(self):
        return random.randint(0, 999)


fake = faker.Faker()
fake.add_provider(FakerProvider)


def collection_name_factory():
    return fake.collection_name()


def collection_schema_factory():
    param = {
        "fields": [
            {"name": fake.normal_field_name(), "type": DataType.INT32},
            {"name": fake.vector_field_name(), "type": DataType.FLOAT_VECTOR, "params": {"dim": random.randint(1, 999)}},
        ],
        "auto_id": True,
    }
    return param


def records_factory(dimension, nq):
    return [[random.random() for _ in range(dimension)] for _ in range(nq)]


def time_it(func):
    @wraps(func)
    def inner(*args, **kwargs):
        pref = time.perf_counter()
        result = func(*args, **kwargs)
        delt = time.perf_counter() - pref
        print(f"[{func.__name__}][{delt:.4}s]")
        return result

    return inner
19
tests/python/pytest.ini
Normal file
@@ -0,0 +1,19 @@
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S

# cli arguments. `-x`: stop the run at the first error or failure.
addopts = -x

testpaths = .

log_cli = true
log_level = 10

timeout = 360

markers =
    level: test level
    serial

; level = 1
8
tests/python/requirements.txt
Normal file
@@ -0,0 +1,8 @@
grpcio==1.26.0
grpcio-tools==1.26.0
numpy==1.18.1
pytest==5.3.4
pytest-cov==2.8.1
pytest-timeout==1.3.4
pymilvus-distributed==0.0.3
sklearn==0.0
1164
tests/python/test_bulk_insert.py
Normal file
File diff suppressed because it is too large
314
tests/python/test_create_collection.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "create_collection"
|
||||
|
||||
class TestCreateCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
)
|
||||
def get_filter_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_vector_fields()
|
||||
)
|
||||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff fields: metric/field_type/...
|
||||
expected: no exception raised
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
# logging.getLogger().info(filter_field)
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
}
|
||||
# logging.getLogger().info(fields)
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff fields: metric/field_type/...
|
||||
expected: no exception raised
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.skip("no segment_row_limit")
|
||||
def test_create_collection_segment_row_limit(self, connect):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
# fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.skip("no flush")
|
||||
def _test_create_collection_auto_flush_disabled(self, connect):
|
||||
'''
|
||||
target: test create normal collection, with large auto_flush_interval
|
||||
method: create collection with correct params
|
||||
expected: create status return ok
|
||||
'''
|
||||
disable_flush(connect)
|
||||
collection_name = gen_unique_str(uid)
|
||||
try:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
finally:
|
||||
enable_flush(connect)
|
||||
|
||||
def test_create_collection_after_insert(self, connect, collection):
|
||||
'''
|
||||
target: test insert vector, then create collection again
|
||||
method: insert vector and create collection
|
||||
expected: error raised
|
||||
'''
|
||||
# pdb.set_trace()
|
||||
connect.bulk_insert(collection, default_entity)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
def test_create_collection_after_insert_flush(self, connect, collection):
|
||||
'''
|
||||
target: test insert vector, then create collection again
|
||||
method: insert vector and create collection
|
||||
expected: error raised
|
||||
'''
|
||||
connect.bulk_insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
def test_create_collection_without_connection(self, dis_connect):
|
||||
'''
|
||||
target: test create collection, without connection
|
||||
method: create collection with correct params, with a disconnected instance
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_collection_existed(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name have already existed
|
||||
method: create collection with the same collection_name
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_after_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: create with the same collection name after collection dropped
|
||||
method: delete, then create
|
||||
expected: create success
|
||||
'''
|
||||
connect.drop_collection(collection)
|
||||
time.sleep(2)
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=create, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for item in collection_names:
|
||||
assert item in connect.list_collections()
|
||||
connect.drop_collection(item)
|
||||
|
||||
|
||||
class TestCreateCollectionInvalid(object):
|
||||
"""
|
||||
Test creating collections with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_metric_types()
|
||||
)
|
||||
def get_metric_type(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ints()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ints()
|
||||
)
|
||||
def get_dim(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_invalid_string(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_field_types()
|
||||
)
|
||||
def get_field_type(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.skip("no segment row limit")
|
||||
def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
collection_name = gen_unique_str()
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_dimension(self, connect, get_dim):
|
||||
dimension = get_dim
|
||||
collection_name = gen_unique_str()
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"]["dim"] = dimension
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
|
||||
collection_name = get_invalid_string
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_collection_None(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name is None
|
||||
method: create collection, param collection_name is None
|
||||
expected: create raise error
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(None, default_fields)
|
||||
|
||||
def test_create_collection_no_dimension(self, connect):
|
||||
'''
|
||||
target: test create collection with no dimension params
|
||||
method: create collection with correct params
|
||||
expected: create status return ok
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"].pop("dim")
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.skip("no segment row limit")
|
||||
def test_create_collection_no_segment_row_limit(self, connect):
|
||||
'''
|
||||
target: test create collection with no segment_row_limit params
|
||||
method: create collection with correct params
|
||||
expected: use default default_segment_row_limit
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields.pop("segment_row_limit")
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
# logging.getLogger().info(res)
|
||||
assert res["segment_row_limit"] == default_server_segment_row_limit
|
||||
|
||||
def test_create_collection_limit_fields(self, connect):
|
||||
collection_name = gen_unique_str(uid)
|
||||
limit_num = 64
|
||||
fields = copy.deepcopy(default_fields)
|
||||
for i in range(limit_num):
|
||||
field_name = gen_unique_str("field_name")
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_name = get_invalid_string
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
def test_create_collection_invalid_field_type(self, connect, get_field_type):
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_type = get_field_type
|
||||
field = {"name": "test_field", "type": field_type}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
98
tests/python/test_drop_collection.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uniq_id = "drop_collection"
|
||||
|
||||
class TestDropCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test delete collection created with correct params
|
||||
method: create collection and then delete,
|
||||
assert the value returned by delete method
|
||||
expected: status ok, and no collection in collections
|
||||
'''
|
||||
connect.drop_collection(collection)
|
||||
time.sleep(2)
|
||||
assert not connect.has_collection(collection)
|
||||
|
||||
def test_drop_collection_without_connection(self, collection, dis_connect):
|
||||
'''
|
||||
target: test describe collection, without connection
|
||||
method: drop collection with correct params, with a disconnected instance
|
||||
expected: drop raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
dis_connect.drop_collection(collection)
|
||||
|
||||
def test_drop_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: use a random collection name that does not exist in the db,
assert the exception raised by the drop_collection method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(uniq_id)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.drop_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_drop_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create and drop collection with multithread
|
||||
method: create and drop collection using multithread,
|
||||
expected: collections are created, and dropped
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(uniq_id)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.drop_collection(collection_name)
|
||||
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=create, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for item in collection_names:
|
||||
assert not connect.has_collection(item)
|
||||
|
||||
|
||||
class TestDropCollectionInvalid(object):
|
||||
"""
|
||||
Test has collection with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
def test_drop_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
def test_drop_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
234
tests/python/test_get_collection_info.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "collection_info"
|
||||
|
||||
class TestInfoBase:
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
)
|
||||
def get_filter_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_vector_fields()
|
||||
)
|
||||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
logging.getLogger().info(request.param)
|
||||
if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("sq8h not support in CPU mode")
|
||||
return request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `get_collection_info` function, no data in collection
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_info_collection_fields(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields, check info returned
|
||||
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
|
||||
expected: no exception raised, and value returned correct
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
assert field["name"] == filter_field["name"]
|
||||
elif field["type"] == vector_field:
|
||||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
# assert segment row count
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
@pytest.mark.skip("no create Index")
|
||||
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
res = connect.get_collection_info(collection)
|
||||
for field in res["fields"]:
|
||||
if field["name"] == default_float_vec_field_name:
|
||||
index = field["indexes"][0]
|
||||
assert index["index_type"] == get_simple_index["index_type"]
|
||||
assert index["metric_type"] == get_simple_index["metric_type"]
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_without_connection(self, connect, collection, dis_connect):
|
||||
'''
|
||||
target: test get collection info, without connection
|
||||
method: calling get collection info with correct params, with a disconnected instance
|
||||
expected: get collection info raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
assert connect.get_collection_info(dis_connect, collection)
|
||||
|
||||
def test_get_collection_info_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: use a random collection name that does not exist in the db,
|
||||
assert the value returned by get_collection_info method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.get_collection_info(connect, collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def get_info():
|
||||
res = connect.get_collection_info(connect, collection_name)
|
||||
# assert
|
||||
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=get_info, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `get_collection_info` function, and insert data in collection
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.skip("no segment row limit and type")
|
||||
def test_info_collection_fields_after_insert(self, connect, get_filter_field, get_vector_field):
|
||||
'''
|
||||
target: test create normal collection with different fields, check info returned
|
||||
method: create collection with diff fields: metric/field_type/..., calling `get_collection_info`
|
||||
expected: no exception raised, and value returned correct
|
||||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
|
||||
res_ids = connect.bulk_insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
assert field["name"] == filter_field["name"]
|
||||
elif field["type"] == vector_field:
|
||||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
@pytest.mark.skip("not segment row limit")
|
||||
def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
|
||||
res_ids = connect.bulk_insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
|
||||
class TestInfoInvalid(object):
|
||||
"""
|
||||
Test get collection info with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_info_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(collection_name)
|
||||
|
||||
def test_get_collection_info_None(self, connect):
|
||||
'''
|
||||
target: test create collection but the collection name is None
|
||||
method: create collection, param collection_name is None
|
||||
expected: create raise error
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.get_collection_info(None)
|
93
tests/python/test_has_collection.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
uid = "has_collection"
|
||||
|
||||
class TestHasCollection:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_collection` function
|
||||
******************************************************************
|
||||
"""
|
||||
def test_has_collection(self, connect, collection):
|
||||
'''
|
||||
target: test if the created collection existed
|
||||
method: create collection, assert the value returned by has_collection method
|
||||
expected: True
|
||||
'''
|
||||
assert connect.has_collection(collection)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_without_connection(self, collection, dis_connect):
|
||||
'''
|
||||
target: test has collection, without connection
|
||||
method: calling has collection with correct params, with a disconnected instance
|
||||
expected: has collection raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
assert dis_connect.has_collection(collection)
|
||||
|
||||
def test_has_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test if collection not created
|
||||
method: use a random collection name that does not exist in the db,
|
||||
assert the value returned by has_collection method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str("test_collection")
|
||||
assert not connect.has_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create collection with multithread
|
||||
method: create collection using multithread,
|
||||
expected: collections are created
|
||||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def has():
|
||||
assert connect.has_collection(collection_name)
|
||||
# assert not assert_collection(connect, collection_name)
|
||||
for i in range(threads_num):
|
||||
t = MilvusTestThread(target=has, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class TestHasCollectionInvalid(object):
|
||||
"""
|
||||
Test has collection with invalid params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_strs()
|
||||
)
|
||||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_empty_collectionname(self, connect):
|
||||
collection_name = ''
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_has_collection_with_none_collectionname(self, connect):
|
||||
collection_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.has_collection(collection_name)
|
462
tests/python/test_insert.py
Normal file
@@ -0,0 +1,462 @@
|
||||
import pytest
|
||||
from .utils import *
|
||||
from .constants import *
|
||||
|
||||
ADD_TIMEOUT = 600
|
||||
uid = "test_insert"
|
||||
field_name = default_float_vec_field_name
|
||||
binary_field_name = default_binary_vec_field_name
|
||||
default_single_query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2",
|
||||
"params": {"nprobe": 10}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestInsertBase:
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `insert` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_simple_index()
|
||||
)
|
||||
def get_simple_index(self, request, connect):
|
||||
if str(connect._cmd("mode")) == "CPU":
|
||||
if request.param["index_type"] in index_cpu_not_support():
|
||||
pytest.skip("CPU not support index_type: ivf_sq8h")
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_filter_fields()
|
||||
)
|
||||
def get_filter_field(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_single_vector_fields()
|
||||
)
|
||||
def get_vector_field(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_add_vector_with_empty_vector(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with empty vectors list
|
||||
method: set empty vectors list as add method params
|
||||
expected: raises an Exception
|
||||
'''
|
||||
vector = []
|
||||
with pytest.raises(Exception) as e:
|
||||
status, ids = connect.insert(collection, vector)
|
||||
|
||||
def test_add_vector_with_None(self, connect, collection):
|
||||
'''
|
||||
target: test add vectors with None
|
||||
method: set None as add method params
|
||||
expected: raises an Exception
|
||||
'''
|
||||
vector = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status, ids = connect.insert(collection, vector)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test insert, with collection not existed
|
||||
method: insert entity into a random named collection
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.insert(collection_name, default_entities_rows)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test delete collection after insert vector
|
||||
method: insert vector and delete collection
|
||||
expected: no error raised
|
||||
'''
|
||||
ids = connect.insert(collection, default_entity_row)
|
||||
assert len(ids) == 1
|
||||
connect.drop_collection(collection)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_sleep_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: test delete collection after insert vector for a while
|
||||
method: insert vector, sleep, and delete collection
|
||||
expected: no error raised
|
||||
'''
|
||||
ids = connect.insert(collection, default_entity_row)
|
||||
assert len(ids) == 1
|
||||
connect.flush([collection])
|
||||
connect.drop_collection(collection)
|
||||
|
||||
@pytest.mark.skip("create_index")
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_create_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test building an index after inserting vectors
|
||||
method: insert vector and build index
|
||||
expected: no error raised
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities_rows)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
info = connect.get_collection_info(collection)
|
||||
fields = info["fields"]
|
||||
for field in fields:
|
||||
if field["name"] == field_name:
|
||||
assert field["indexes"][0] == get_simple_index
|
||||
|
||||
@pytest.mark.skip("create_index")
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_after_create_index(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test inserting vectors after the index is built
method: build index first, then insert vectors
|
||||
expected: no error raised
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
ids = connect.insert(collection, default_entities_rows)
|
||||
assert len(ids) == default_nb
|
||||
info = connect.get_collection_info(collection)
|
||||
fields = info["fields"]
|
||||
for field in fields:
|
||||
if field["name"] == field_name:
|
||||
assert field["indexes"][0] == get_simple_index
|
||||
|
||||
@pytest.mark.skip(" todo fix search")
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_search(self, connect, collection):
|
||||
'''
|
||||
target: test search vector after insert vector after a while
|
||||
method: insert vector, sleep, and search collection
|
||||
expected: no error raised
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities_rows)
|
||||
connect.flush([collection])
|
||||
res = connect.search(collection, default_single_query)
|
||||
logging.getLogger().debug(res)
|
||||
assert res
|
||||
|
||||
@pytest.mark.skip("segment row count")
|
||||
def test_insert_segment_row_count(self, connect, collection):
|
||||
nb = default_segment_row_limit + 1
|
||||
res_ids = connect.insert(collection, gen_entities_rows(nb))
|
||||
connect.flush([collection])
|
||||
assert len(res_ids) == nb
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert len(stats['partitions'][0]['segments']) == 2
|
||||
for segment in stats['partitions'][0]['segments']:
|
||||
assert segment['row_count'] in [default_segment_row_limit, 1]
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
2000
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_ids_not_match(self, connect, id_collection, insert_count):
|
||||
'''
|
||||
target: test insert vectors in collection, use customize ids
|
||||
method: create collection and insert vectors in it, check the ids returned and the collection length after vectors inserted
|
||||
expected: the length of ids and the collection row count
|
||||
'''
|
||||
nb = insert_count
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids = connect.insert(id_collection, gen_entities_rows(nb))
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_twice_ids_no_ids(self, connect, collection):
|
||||
'''
|
||||
target: check the result of insert, with params ids and no ids
|
||||
method: test insert vectors twice, use customize ids first, and then use no ids
|
||||
expected: error raised
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
res_ids = connect.insert(collection, gen_entities_rows(default_nb, _id=False))
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_tag(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities in collection created before
|
||||
method: create collection and insert entities in it, with the partition_tag param
|
||||
expected: the collection row count equals to nq
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
|
||||
assert len(ids) == default_nb
|
||||
assert connect.has_partition(collection, default_tag)
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_tag_with_ids(self, connect, id_collection):
|
||||
'''
|
||||
target: test insert entities in collection created before, insert with ids
|
||||
method: create collection and insert entities in it, with the partition_tag param
|
||||
expected: the collection row count equals to nq
|
||||
'''
|
||||
connect.create_partition(id_collection, default_tag)
|
||||
ids = [i for i in range(default_nb)]
|
||||
res_ids = connect.insert(id_collection, gen_entities_rows(default_nb, _id=False), partition_tag=default_tag)
|
||||
assert res_ids == ids
|
||||
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_tag_not_existed(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities in collection created before
|
||||
method: create collection and insert entities in it, with the not existed partition_tag param
|
||||
expected: error raised
|
||||
'''
|
||||
tag = gen_unique_str()
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = connect.insert(collection, default_entities_rows, partition_tag=tag)
|
||||
|
||||
@pytest.mark.skip("todo support count entities")
|
||||
@pytest.mark.timeout(ADD_TIMEOUT)
|
||||
def test_insert_tag_existed(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities in collection created before
|
||||
method: create collection and insert entities in it repeatedly, with the partition_tag param
|
||||
expected: the collection row count equals to nq
|
||||
'''
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
|
||||
ids = connect.insert(collection, default_entities_rows, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
res_count = connect.count_entities(collection)
|
||||
assert res_count == 2 * default_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_insert_collection_not_existed(self, connect):
|
||||
'''
|
||||
target: test insert entities in collection, which not existed before
|
||||
method: insert entities collection not existed, check the status
|
||||
expected: error raised
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = connect.insert(gen_unique_str("not_exist_collection"), default_entities_rows)
|
||||
|
||||
@pytest.mark.skip("todo support row data check")
|
||||
def test_insert_dim_not_matched(self, connect, collection):
|
||||
'''
|
||||
target: test insert entities, the vector dimension is not equal to the collection dimension
|
||||
method: the entities dimension is half of the collection dimension, check the status
|
||||
expected: error raised
|
||||
'''
|
||||
vectors = gen_vectors(default_nb, int(default_dim) // 2)
|
||||
insert_entities = copy.deepcopy(default_entities_rows)
|
||||
insert_entities[-1][default_float_vec_field_name] = vectors
|
||||
with pytest.raises(Exception) as e:
|
||||
ids = connect.insert(collection, insert_entities)
|
||||
|
||||
|
||||
class TestInsertBinary:
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_binary_index()
|
||||
)
|
||||
def get_binary_index(self, request):
|
||||
request.param["metric_type"] = "JACCARD"
|
||||
return request.param
|
||||
|
||||
@pytest.mark.skip("count entities")
|
||||
def test_insert_binary_entities(self, connect, binary_collection):
|
||||
'''
|
||||
target: test insert entities in binary collection
|
||||
method: create collection and insert binary entities in it
|
||||
expected: the collection row count equals to nb
|
||||
'''
|
||||
ids = connect.insert(binary_collection, default_binary_entities_rows)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush()
|
||||
assert connect.count_entities(binary_collection) == default_nb

    def test_insert_binary_tag(self, connect, binary_collection):
        '''
        target: test insert entities and create partition tag
        method: create collection and insert binary entities in it, with the partition_tag param
        expected: the collection row count equals nb
        '''
        connect.create_partition(binary_collection, default_tag)
        ids = connect.insert(binary_collection, default_binary_entities_rows, partition_tag=default_tag)
        assert len(ids) == default_nb
        assert connect.has_partition(binary_collection, default_tag)

    # TODO
    @pytest.mark.skip("count entities")
    @pytest.mark.level(2)
    def test_insert_binary_multi_times(self, connect, binary_collection):
        '''
        target: test inserting entities multiple times with a final flush
        method: create collection, insert a binary entity multiple times, then flush
        expected: the collection row count equals nb
        '''
        for i in range(default_nb):
            ids = connect.insert(binary_collection, default_binary_entity_row)
            assert len(ids) == 1
        connect.flush([binary_collection])
        assert connect.count_entities(binary_collection) == default_nb

    @pytest.mark.skip("create index")
    def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
        '''
        target: test inserting binary entities after building an index
        method: build index, then insert entities
        expected: no error raised
        '''
        connect.create_index(binary_collection, binary_field_name, get_binary_index)
        ids = connect.insert(binary_collection, default_binary_entities_rows)
        assert len(ids) == default_nb
        connect.flush([binary_collection])
        info = connect.get_collection_info(binary_collection)
        fields = info["fields"]
        for field in fields:
            if field["name"] == binary_field_name:
                assert field["indexes"][0] == get_binary_index

    @pytest.mark.skip("create index")
    @pytest.mark.timeout(ADD_TIMEOUT)
    def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index):
        '''
        target: test building an index after inserting vectors
        method: insert vectors, then build index
        expected: no error raised
        '''
        ids = connect.insert(binary_collection, default_binary_entities_rows)
        assert len(ids) == default_nb
        connect.flush([binary_collection])
        connect.create_index(binary_collection, binary_field_name, get_binary_index)
        info = connect.get_collection_info(binary_collection)
        fields = info["fields"]
        for field in fields:
            if field["name"] == binary_field_name:
                assert field["indexes"][0] == get_binary_index


class TestInsertInvalid(object):
    """
    Test inserting entities with invalid parameters: collection name, partition tag,
    field name/value and vector values
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_tag_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_field_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_field_type(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_field_int_value(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_entity_id(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_vectors()
    )
    def get_field_vectors_value(self, request):
        yield request.param
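    # gen_invalid_strs(), gen_invalid_ints() and gen_invalid_vectors() are helpers
    # from the shared utils module; they are assumed to yield a parametrized set of
    # malformed values (e.g. None, empty strings, wrong types) so each fixture
    # drives the tests below once per invalid value.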

    @pytest.mark.skip("todo support row data check")
    def test_insert_field_name_not_match(self, connect, collection):
        '''
        target: test insert, with field name not matched
        method: add an extra, undeclared field to the entity, then insert it
        expected: raise an exception
        '''
        tmp_entity = copy.deepcopy(default_entity_row)
        tmp_entity[0]["string"] = "string"
        with pytest.raises(Exception):
            connect.insert(collection, tmp_entity)

    def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception):
            connect.insert(collection_name, default_entity_row)

    def test_insert_with_invalid_tag_name(self, connect, collection, get_tag_name):
        tag_name = get_tag_name
        connect.create_partition(collection, default_tag)
        if tag_name is not None:
            with pytest.raises(Exception):
                connect.insert(collection, default_entity_row, partition_tag=tag_name)
        else:
            connect.insert(collection, default_entity_row, partition_tag=tag_name)

    @pytest.mark.skip("todo support row data check")
    def test_insert_with_less_field(self, connect, collection):
        tmp_entity = copy.deepcopy(default_entity_row)
        tmp_entity[0].pop(default_float_vec_field_name)
        with pytest.raises(Exception):
            connect.insert(collection, tmp_entity)

    def test_insert_with_less_field_id(self, connect, id_collection):
        tmp_entity = copy.deepcopy(gen_entities_rows(default_nb, _id=False))
        tmp_entity[0].pop("_id")
        with pytest.raises(Exception):
            connect.insert(id_collection, tmp_entity)

    def test_insert_with_more_field(self, connect, collection):
        tmp_entity = copy.deepcopy(default_entity_row)
        tmp_entity[0]["new_field"] = 1
        with pytest.raises(Exception):
            connect.insert(collection, tmp_entity)

    def test_insert_with_more_field_id(self, connect, collection):
        tmp_entity = copy.deepcopy(default_entity_row)
        tmp_entity[0]["_id"] = 1
        with pytest.raises(Exception):
            connect.insert(collection, tmp_entity)

    @pytest.mark.skip("todo support row data check")
    def test_insert_with_invalid_field_vector_value(self, connect, collection, get_field_vectors_value):
        tmp_entity = copy.deepcopy(default_entity_row)
        tmp_entity[0][default_float_vec_field_name][1] = get_field_vectors_value
        with pytest.raises(Exception):
            connect.insert(collection, tmp_entity)
86
tests/python/test_list_collections.py
Normal file
@ -0,0 +1,86 @@
import pytest
from .utils import *
from .constants import *

uid = "list_collections"

class TestListCollections:
    """
    ******************************************************************
    The following cases are used to test `list_collections` function
    ******************************************************************
    """
    def test_list_collections(self, connect, collection):
        '''
        target: test list collections
        method: create collection, assert the value returned by list_collections method
        expected: True
        '''
        assert collection in connect.list_collections()

    def test_list_collections_multi_collections(self, connect):
        '''
        target: test list collections
        method: create collection, assert the value returned by list_collections method
        expected: True
        '''
        collection_num = 50
        for i in range(collection_num):
            collection_name = gen_unique_str(uid)
            connect.create_collection(collection_name, default_fields)
            assert collection_name in connect.list_collections()

    @pytest.mark.level(2)
    def test_list_collections_without_connection(self, dis_connect):
        '''
        target: test list collections, without connection
        method: calling list collections with correct params, with a disconnected instance
        expected: list collections raise exception
        '''
        with pytest.raises(Exception) as e:
            dis_connect.list_collections()

    def test_list_collections_not_existed(self, connect):
        '''
        target: test listing a collection that was never created
        method: generate a random collection name that does not exist in the db,
                then check the value returned by list_collections
        expected: the name is not in the returned list
        '''
        collection_name = gen_unique_str(uid)
        assert collection_name not in connect.list_collections()

    @pytest.mark.level(2)
    def test_list_collections_no_collection(self, connect):
        '''
        target: test that list_collections returns only existing collections, even when the db may be empty
        method: call list_collections and check every returned name with has_collection
        expected: has_collection is True for each listed name (an empty list is also valid)
        '''
        result = connect.list_collections()
        if result:
            for collection_name in result:
                assert connect.has_collection(collection_name)

    @pytest.mark.level(2)
    def test_list_collections_multithread(self, connect):
        '''
        target: test list_collections with multiple threads
        method: create a collection, then call list_collections from several threads
        expected: the collection is visible in every thread
        '''
        threads_num = 4
        threads = []
        collection_name = gen_unique_str(uid)
        connect.create_collection(collection_name, default_fields)

        def _list():
            assert collection_name in connect.list_collections()
        for i in range(threads_num):
            t = threading.Thread(target=_list, args=())
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
396
tests/python/test_partition.py
Normal file
@ -0,0 +1,396 @@
import pytest
from .utils import *
from .constants import *

TIMEOUT = 120

class TestCreateBase:
    """
    ******************************************************************
    The following cases are used to test `create_partition` function
    ******************************************************************
    """
    def test_create_partition(self, connect, collection):
        '''
        target: test create partition, check status returned
        method: call function: create_partition
        expected: status ok
        '''
        connect.create_partition(collection, default_tag)

    @pytest.mark.level(2)
    @pytest.mark.timeout(600)
    @pytest.mark.skip
    def test_create_partition_limit(self, connect, collection, args):
        '''
        target: test create partitions, check status returned
        method: call function: create_partition for 4097 times
        expected: exception raised
        '''
        threads_num = 8
        threads = []
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")

        def create(connect, threads_num):
            for i in range(max_partition_num // threads_num):
                tag_tmp = gen_unique_str()
                connect.create_partition(collection, tag_tmp)

        for i in range(threads_num):
            m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
            t = threading.Thread(target=create, args=(m, threads_num, ))
            threads.append(t)
            t.start()
        for t in threads:
            t.join()
        tag_tmp = gen_unique_str()
        with pytest.raises(Exception) as e:
            connect.create_partition(collection, tag_tmp)
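        # The threads above fill the collection up to max_partition_num partitions
        # (imported from the shared test constants, assumed to mirror the server-side
        # per-collection limit that the docstring's "4097 times" refers to), so this
        # final create_partition call is the one expected to exceed the limit and raise.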

    def test_create_partition_repeat(self, connect, collection):
        '''
        target: test creating the same partition twice, check status returned
        method: call function: create_partition twice with the same tag
        expected: exception raised on the second call
        '''
        connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            connect.create_partition(collection, default_tag)

    def test_create_partition_collection_not_existed(self, connect):
        '''
        target: test create partition whose owner collection does not exist in the db, check status returned
        method: call function: create_partition
        expected: status not ok
        '''
        collection_name = gen_unique_str()
        with pytest.raises(Exception) as e:
            connect.create_partition(collection_name, default_tag)

    def test_create_partition_tag_name_None(self, connect, collection):
        '''
        target: test create partition with tag name set to None, check status returned
        method: call function: create_partition
        expected: exception raised
        '''
        tag_name = None
        with pytest.raises(Exception) as e:
            connect.create_partition(collection, tag_name)

    def test_create_different_partition_tags(self, connect, collection):
        '''
        target: test create partition twice with different names
        method: call function: create_partition, and again
        expected: status ok
        '''
        connect.create_partition(collection, default_tag)
        tag_name = gen_unique_str()
        connect.create_partition(collection, tag_name)
        tag_list = connect.list_partitions(collection)
        assert default_tag in tag_list
        assert tag_name in tag_list
        assert "_default" in tag_list

    @pytest.mark.skip("not support custom id")
    def test_create_partition_insert_default(self, connect, id_collection):
        '''
        target: test create partition, and insert vectors, check status returned
        method: call function: create_partition
        expected: status ok
        '''
        connect.create_partition(id_collection, default_tag)
        ids = [i for i in range(default_nb)]
        insert_ids = connect.bulk_insert(id_collection, default_entities, ids)
        assert len(insert_ids) == len(ids)

    @pytest.mark.skip("not support custom id")
    def test_create_partition_insert_with_tag(self, connect, id_collection):
        '''
        target: test create partition, and insert vectors, check status returned
        method: call function: create_partition
        expected: status ok
        '''
        connect.create_partition(id_collection, default_tag)
        ids = [i for i in range(default_nb)]
        insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        assert len(insert_ids) == len(ids)

    def test_create_partition_insert_with_tag_not_existed(self, connect, collection):
        '''
        target: test create partition, and insert vectors, check status returned
        method: call function: create_partition
        expected: status not ok
        '''
        tag_new = "tag_new"
        connect.create_partition(collection, default_tag)
        ids = [i for i in range(default_nb)]
        with pytest.raises(Exception) as e:
            insert_ids = connect.bulk_insert(collection, default_entities, ids, partition_tag=tag_new)

    @pytest.mark.skip("not support custom id")
    def test_create_partition_insert_same_tags(self, connect, id_collection):
        '''
        target: test create partition, and insert vectors, check status returned
        method: call function: create_partition
        expected: status ok
        '''
        connect.create_partition(id_collection, default_tag)
        ids = [i for i in range(default_nb)]
        insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        ids = [(i+default_nb) for i in range(default_nb)]
        new_insert_ids = connect.bulk_insert(id_collection, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
        assert res == default_nb * 2

    @pytest.mark.level(2)
    @pytest.mark.skip("not support count entities")
    def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
        '''
        target: test create two partitions, and insert vectors with the same tag to each collection, check status returned
        method: call function: create_partition
        expected: status ok, collection length is correct
        '''
        connect.create_partition(collection, default_tag)
        collection_new = gen_unique_str()
        connect.create_collection(collection_new, default_fields)
        connect.create_partition(collection_new, default_tag)
        ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
        ids = connect.bulk_insert(collection_new, default_entities, partition_tag=default_tag)
        connect.flush([collection, collection_new])
        res = connect.count_entities(collection)
        assert res == default_nb
        res = connect.count_entities(collection_new)
        assert res == default_nb


class TestShowBase:

    """
    ******************************************************************
    The following cases are used to test `list_partitions` function
    ******************************************************************
    """
    def test_list_partitions(self, connect, collection):
        '''
        target: test show partitions, check status and partitions returned
        method: create partition first, then call function: list_partitions
        expected: status ok, partition correct
        '''
        connect.create_partition(collection, default_tag)
        res = connect.list_partitions(collection)
        assert default_tag in res

    def test_list_partitions_no_partition(self, connect, collection):
        '''
        target: test show partitions with collection name, check status and partitions returned
        method: call function: list_partitions
        expected: status ok, partitions correct
        '''
        res = connect.list_partitions(collection)
        assert len(res) == 1

    def test_show_multi_partitions(self, connect, collection):
        '''
        target: test show partitions, check status and partitions returned
        method: create partitions first, then call function: list_partitions
        expected: status ok, partitions correct
        '''
        tag_new = gen_unique_str()
        connect.create_partition(collection, default_tag)
        connect.create_partition(collection, tag_new)
        res = connect.list_partitions(collection)
        assert default_tag in res
        assert tag_new in res

class TestHasBase:

    """
    ******************************************************************
    The following cases are used to test `has_partition` function
    ******************************************************************
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_tag_name(self, request):
        yield request.param

    def test_has_partition(self, connect, collection):
        '''
        target: test has_partition, check status and result
        method: create partition first, then call function: has_partition
        expected: status ok, result true
        '''
        connect.create_partition(collection, default_tag)
        res = connect.has_partition(collection, default_tag)
        logging.getLogger().info(res)
        assert res

    def test_has_partition_multi_partitions(self, connect, collection):
        '''
        target: test has_partition, check status and result
        method: create partitions first, then call function: has_partition for each
        expected: status ok, result true
        '''
        for tag_name in [default_tag, "tag_new", "tag_new_new"]:
            connect.create_partition(collection, tag_name)
        for tag_name in [default_tag, "tag_new", "tag_new_new"]:
            res = connect.has_partition(collection, tag_name)
            assert res

    def test_has_partition_tag_not_existed(self, connect, collection):
        '''
        target: test has_partition, check status and result
        method: call function: has_partition with a tag that does not exist
        expected: status ok, result empty
        '''
        res = connect.has_partition(collection, default_tag)
        logging.getLogger().info(res)
        assert not res

    def test_has_partition_collection_not_existed(self, connect, collection):
        '''
        target: test has_partition, check status and result
        method: call function: has_partition with a collection that does not exist
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
            res = connect.has_partition("not_existed_collection", default_tag)

    @pytest.mark.level(2)
    def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
        '''
        target: test has partition, with invalid tag name, check status returned
        method: call function: has_partition
        expected: exception raised
        '''
        tag_name = get_tag_name
        connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            res = connect.has_partition(collection, tag_name)


class TestDropBase:

    """
    ******************************************************************
    The following cases are used to test `drop_partition` function
    ******************************************************************
    """
    def test_drop_partition(self, connect, collection):
        '''
        target: test drop partition, check status and that the partition no longer exists
        method: create partition first, then call function: drop_partition
        expected: status ok, partition not in the returned partition list
        '''
        connect.create_partition(collection, default_tag)
        connect.drop_partition(collection, default_tag)
        res = connect.list_partitions(collection)
        assert default_tag not in res

    def test_drop_partition_tag_not_existed(self, connect, collection):
        '''
        target: test drop partition with a tag that does not exist
        method: create partition first, then call function: drop_partition with a different tag
        expected: status not ok
        '''
        connect.create_partition(collection, default_tag)
        new_tag = "new_tag"
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection, new_tag)

    def test_drop_partition_tag_not_existed_A(self, connect, collection):
        '''
        target: test drop partition on a collection that does not exist
        method: create partition first, then call function: drop_partition with another collection name
        expected: status not ok
        '''
        connect.create_partition(collection, default_tag)
        new_collection = gen_unique_str()
        with pytest.raises(Exception) as e:
            connect.drop_partition(new_collection, default_tag)

    @pytest.mark.level(2)
    def test_drop_partition_repeatedly(self, connect, collection):
        '''
        target: test drop partition twice, check status and that the partition no longer exists
        method: create partition first, then call function: drop_partition twice
        expected: exception raised on the second drop, partition not in db
        '''
        connect.create_partition(collection, default_tag)
        connect.drop_partition(collection, default_tag)
        time.sleep(2)
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection, default_tag)
        tag_list = connect.list_partitions(collection)
        assert default_tag not in tag_list

    def test_drop_partition_create(self, connect, collection):
        '''
        target: test drop partition, then create it again, check status
        method: create partition first, then call function: drop_partition, create_partition
        expected: status ok, partition in db
        '''
        connect.create_partition(collection, default_tag)
        connect.drop_partition(collection, default_tag)
        time.sleep(2)
        connect.create_partition(collection, default_tag)
        tag_list = connect.list_partitions(collection)
        assert default_tag in tag_list


class TestNameInvalid(object):
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_tag_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name):
        '''
        target: test drop partition, with invalid collection name, check status returned
        method: call function: drop_partition
        expected: status not ok
        '''
        collection_name = get_collection_name
        connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection_name, default_tag)

    @pytest.mark.level(2)
    def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
        '''
        target: test drop partition, with invalid tag name, check status returned
        method: call function: drop_partition
        expected: status not ok
        '''
        tag_name = get_tag_name
        connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection, tag_name)

    @pytest.mark.level(2)
    def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name):
        '''
        target: test show partitions, with invalid collection name, check status returned
        method: call function: list_partitions
        expected: status not ok
        '''
        collection_name = get_collection_name
        connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            res = connect.list_partitions(collection_name)
1824
tests/python/test_search.py
Normal file
File diff suppressed because it is too large
1005
tests/python/utils.py
Normal file
File diff suppressed because it is too large