From 77fa75b1ecda20b43b0803d68c0f6ab465ebfe07 Mon Sep 17 00:00:00 2001 From: FluorineDog Date: Mon, 30 Nov 2020 05:18:44 +0800 Subject: [PATCH] Add binary insert and wrapper of binary search, rename vector Signed-off-by: FluorineDog --- .jenkins/modules/Build/Build.groovy | 13 ++ .jenkins/modules/Package/Package.groovy | 7 ++ .jenkins/modules/Publish/Publish.groovy | 32 +++++ .jenkins/modules/UnitTest/UnitTest.groovy | 3 + Makefile | 8 +- build/ci/jenkins/Jenkinsfile | 68 +++++++++++ build/ci/jenkins/pod/build-env.yaml | 51 ++++++++ build/ci/jenkins/pod/docker-pod.yaml | 35 ++++++ build/docker/deploy/.env | 4 + build/docker/deploy/docker-compose.yml | 14 ++- internal/allocator/segment.go | 2 +- internal/core/build-support/add_license.sh | 23 ++-- internal/core/run_clang_format.sh | 1 + internal/core/src/pb/service_msg.pb.cc | 24 ++-- internal/core/src/pb/service_msg.pb.h | 26 ++-- internal/core/src/query/BruteForceSearch.cpp | 88 ++++++++++++++ internal/core/src/query/BruteForceSearch.h | 29 +++++ internal/core/src/query/CMakeLists.txt | 1 + internal/core/src/query/Search.cpp | 4 +- .../src/query/visitors/ExecExprVisitor.cpp | 2 +- internal/core/src/segcore/ConcurrentVector.h | 40 ++++-- internal/core/src/segcore/DeletedRecord.h | 4 +- internal/core/src/segcore/IndexingEntry.cpp | 11 +- internal/core/src/segcore/IndexingEntry.h | 4 +- internal/core/src/segcore/InsertRecord.cpp | 24 ++-- internal/core/src/segcore/InsertRecord.h | 24 +--- internal/core/src/segcore/SegmentNaive.cpp | 11 +- .../core/src/segcore/SegmentSmallIndex.cpp | 4 +- internal/core/src/segcore/collection_c.h | 12 ++ internal/core/src/segcore/plan_c.cpp | 47 ++++++-- internal/core/src/segcore/plan_c.h | 11 +- internal/core/src/segcore/segment_c.cpp | 6 +- internal/core/src/segcore/segment_c.h | 11 -- internal/core/unittest/CMakeLists.txt | 4 +- internal/core/unittest/test_binary.cpp | 31 +++++ internal/core/unittest/test_c_api.cpp | 23 +++- .../core/unittest/test_concurrent_vector.cpp | 4 +- 
internal/core/unittest/test_indexing.cpp | 108 ++++++++++++++++- internal/core/unittest/test_query.cpp | 13 +- internal/core/unittest/test_utils/DataGen.h | 55 +++++++++ internal/master/meta_table.go | 10 +- internal/master/segment_manager_test.go | 4 +- internal/master/time_sync_producer.go | 2 +- internal/master/timesync.go | 2 +- internal/master/timesync_test.go | 2 +- internal/msgstream/msgstream.go | 2 +- internal/msgstream/msgstream_test.go | 12 +- internal/msgstream/task.go | 6 +- internal/msgstream/task_test.go | 4 +- internal/proto/service_msg.proto | 2 +- internal/proto/servicepb/service_msg.pb.go | 66 +++++----- internal/proxy/meta_cache.go | 77 +++++++++--- internal/proxy/proxy_test.go | 4 +- internal/proxy/task.go | 57 +++++++-- internal/proxy/timetick.go | 2 +- internal/querynode/data_sync_service_test.go | 6 +- internal/querynode/plan.go | 34 +++++- internal/querynode/plan_test.go | 9 +- internal/querynode/reduce_test.go | 6 +- internal/querynode/search_service.go | 37 +++--- internal/querynode/search_service_test.go | 8 +- internal/querynode/segment_test.go | 6 +- internal/querynode/stats_service.go | 2 +- internal/util/paramtable/paramtable.go | 13 +- scripts/before-install.sh | 14 +-- scripts/check_cache.sh | 114 ++++++++++++++++++ scripts/update_cache.sh | 104 ++++++++++++++++ 67 files changed, 1210 insertions(+), 277 deletions(-) create mode 100644 .jenkins/modules/Build/Build.groovy create mode 100644 .jenkins/modules/Package/Package.groovy create mode 100644 .jenkins/modules/Publish/Publish.groovy create mode 100644 .jenkins/modules/UnitTest/UnitTest.groovy create mode 100644 build/ci/jenkins/Jenkinsfile create mode 100644 build/ci/jenkins/pod/build-env.yaml create mode 100644 build/ci/jenkins/pod/docker-pod.yaml create mode 100644 internal/core/src/query/BruteForceSearch.cpp create mode 100644 internal/core/src/query/BruteForceSearch.h create mode 100644 internal/core/unittest/test_binary.cpp create mode 100755 scripts/check_cache.sh create 
mode 100755 scripts/update_cache.sh diff --git a/.jenkins/modules/Build/Build.groovy b/.jenkins/modules/Build/Build.groovy new file mode 100644 index 0000000000..998ed76e99 --- /dev/null +++ b/.jenkins/modules/Build/Build.groovy @@ -0,0 +1,13 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("scripts") { + sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"' + } + + sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install' + + dir ("scripts") { + withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./update_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz -u ${USERNAME} -p ${PASSWORD}' + } + } +} diff --git a/.jenkins/modules/Package/Package.groovy b/.jenkins/modules/Package/Package.groovy new file mode 100644 index 0000000000..256b38fec9 --- /dev/null +++ b/.jenkins/modules/Package/Package.groovy @@ -0,0 +1,7 @@ +sh 'tar -zcvf ./${PACKAGE_NAME} ./bin ./configs ./lib' +withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'JFROG_USERNAME', passwordVariable: 'JFROG_PASSWORD')]) { + def uploadStatus = sh(returnStatus: true, script: 'curl -u${JFROG_USERNAME}:${JFROG_PASSWORD} -T ./${PACKAGE_NAME} ${PACKAGE_ARTFACTORY_URL}') + if (uploadStatus != 0) { + error("\" ${PACKAGE_NAME} \" upload to \" ${PACKAGE_ARTFACTORY_URL} \" failed!") + } +} diff --git a/.jenkins/modules/Publish/Publish.groovy b/.jenkins/modules/Publish/Publish.groovy new file mode 100644 index 0000000000..06d18d906c --- /dev/null +++ b/.jenkins/modules/Publish/Publish.groovy @@ -0,0 +1,32 @@ +withCredentials([usernamePassword(credentialsId: 
"${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'JFROG_USERNAME', passwordVariable: 'JFROG_PASSWORD')]) { + def downloadStatus = sh(returnStatus: true, script: 'curl -u${JFROG_USERNAME}:${JFROG_PASSWORD} -O ${PACKAGE_ARTFACTORY_URL}') + + if (downloadStatus != 0) { + error("\" Download \" ${PACKAGE_ARTFACTORY_URL} \" failed!") + } +} + +sh 'tar zxvf ${PACKAGE_NAME}' + +dir ('build/docker/deploy') { + try { + withCredentials([usernamePassword(credentialsId: "${env.DOCKER_CREDENTIALS_ID}", usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) { + sh 'docker login -u ${DOCKER_USERNAME} -p ${DOCKER_PASSWORD} ${DOKCER_REGISTRY_URL}' + sh 'docker pull ${SOURCE_REPO}/master:${SOURCE_TAG} || true' + sh 'docker-compose build --force-rm master' + sh 'docker-compose push master' + sh 'docker pull ${SOURCE_REPO}/proxy:${SOURCE_TAG} || true' + sh 'docker-compose build --force-rm proxy' + sh 'docker-compose push proxy' + sh 'docker pull ${SOURCE_REPO}/querynode:${SOURCE_TAG} || true' + sh 'docker-compose build --force-rm querynode' + sh 'docker-compose push querynode' + } + } catch (exc) { + throw exc + } finally { + sh 'docker logout ${DOKCER_REGISTRY_URL}' + sh "docker rmi -f \$(docker images | grep '' | awk '{print \$3}') || true" + sh 'docker-compose down --rmi all' + } +} diff --git a/.jenkins/modules/UnitTest/UnitTest.groovy b/.jenkins/modules/UnitTest/UnitTest.groovy new file mode 100644 index 0000000000..f78834e99c --- /dev/null +++ b/.jenkins/modules/UnitTest/UnitTest.groovy @@ -0,0 +1,3 @@ +timeout(time: 5, unit: 'MINUTES') { + sh 'make unittest' +} \ No newline at end of file diff --git a/Makefile b/Makefile index 61c2c19aff..0856c124cf 100644 --- a/Makefile +++ b/Makefile @@ -55,13 +55,13 @@ verifiers: cppcheck fmt lint ruleguard # Builds various components locally. build-go: - @echo "Building each component's binary to './'" + @echo "Building each component's binary to './bin'" @echo "Building query node ..." 
- @mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/querynode $(PWD)/cmd/querynode/query_node.go 1>/dev/null + @mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/querynode $(PWD)/cmd/querynode/query_node.go 1>/dev/null @echo "Building master ..." - @mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null + @mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null @echo "Building proxy ..." - @mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/proxy $(PWD)/cmd/proxy/proxy.go 1>/dev/null + @mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/proxy $(PWD)/cmd/proxy/proxy.go 1>/dev/null build-cpp: @(env bash $(PWD)/scripts/core_build.sh) diff --git a/build/ci/jenkins/Jenkinsfile b/build/ci/jenkins/Jenkinsfile new file mode 100644 index 0000000000..679a2f27ac --- /dev/null +++ b/build/ci/jenkins/Jenkinsfile @@ -0,0 +1,68 @@ +#!/usr/bin/env groovy +@Library('mpl') _ + +pipeline { + agent none + options { + timestamps() + } + environment { + PROJECT_NAME = "milvus-distributed" + SEMVER = "${BRANCH_NAME.contains('/') ? 
BRANCH_NAME.substring(BRANCH_NAME.lastIndexOf('/') + 1) : BRANCH_NAME}" + BUILD_TYPE = "Release" + LOWER_BUILD_TYPE = BUILD_TYPE.toLowerCase() + PACKAGE_VERSION = "${SEMVER}-${LOWER_BUILD_TYPE}" + PACKAGE_NAME = "${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz" + JFROG_CREDENTIALS_ID = "1a527823-d2b7-44fd-834b-9844350baf14" + JFROG_ARTFACTORY_URL = "http://192.168.1.201/artifactory/milvus" + PACKAGE_ARTFACTORY_URL = "${JFROG_ARTFACTORY_URL}/${PROJECT_NAME}/package/${PACKAGE_NAME}" + DOCKER_CREDENTIALS_ID = "ba070c98-c8cc-4f7c-b657-897715f359fc" + DOKCER_REGISTRY_URL = "registry.zilliz.com" + } + stages { + stage ('Build and UnitTest') { + agent { + kubernetes { + label "${env.PROJECT_NAME}-${SEMVER}-${env.BUILD_NUMBER}-build" + defaultContainer 'build-env' + customWorkspace '/home/jenkins/agent/workspace' + yamlFile "build/ci/jenkins/pod/build-env.yaml" + } + } + environment { + PULSAR_ADDRESS = "pulsar://127.0.0.1:6650" + ETCD_ADDRESS = "127.0.0.1:2379" + CCACHE_ARTFACTORY_URL = "${JFROG_ARTFACTORY_URL}/milvus-distributed/ccache" + } + steps { + container('build-env') { + MPLModule('Build') + MPLModule('Package') + // MPLModule('UnitTest') + } + } + } + + stage ('Publish Docker Images') { + agent { + kubernetes { + label "${env.PROJECT_NAME}-${SEMVER}-${env.BUILD_NUMBER}-publish" + defaultContainer 'publish-images' + yamlFile "build/ci/jenkins/pod/docker-pod.yaml" + } + } + environment{ + SOURCE_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed" + TARGET_REPO = "${DOKCER_REGISTRY_URL}/milvus-distributed" + SOURCE_TAG = "${CHANGE_TARGET ? 
CHANGE_TARGET : SEMVER}-${LOWER_BUILD_TYPE}" + TARGET_TAG = "${SEMVER}-${LOWER_BUILD_TYPE}" + DOCKER_BUILDKIT = 1 + } + steps { + container('publish-images') { + MPLModule('Publish') + } + } + } + } +} diff --git a/build/ci/jenkins/pod/build-env.yaml b/build/ci/jenkins/pod/build-env.yaml new file mode 100644 index 0000000000..ba8d5cfe97 --- /dev/null +++ b/build/ci/jenkins/pod/build-env.yaml @@ -0,0 +1,51 @@ +apiVersion: v1 +kind: Pod +metadata: + name: build-env + labels: + app: milvus-distributed + componet: build-env +spec: + securityContext: + runAsUser: 2000 + runAsGroup: 2000 + containers: + - name: build-env + image: milvusdb/milvus-distributed-dev:amd64-ubuntu18.04-20201124-101232 + env: + - name: OS_NAME + value: "ubuntu18.04" + - name: BUILD_ENV_IMAGE_ID + value: "f0f52760fde8758793f5a68c39ba20298c812e754de337ba4cc7fd8edf4ae7a9" + command: + - cat + tty: true + resources: + limits: + memory: "16Gi" + cpu: "8.0" + requests: + memory: "8Gi" + cpu: "4.0" + - name: etcd + image: quay.io/coreos/etcd:v3.4.13 + env: + - name: ETCD_LISTEN_CLIENT_URLS + value: "http://0.0.0.0:2379" + - name: ETCD_ADVERTISE_CLIENT_URLS + value: "http://0.0.0.0:2379" + ports: + - containerPort: 2379 + name: etcd + securityContext: + runAsUser: 0 + runAsGroup: 0 + - name: pulsar + image: apachepulsar/pulsar:2.6.1 + ports: + - containerPort: 6650 + name: pulsar + command: ["bin/pulsar", "standalone"] + securityContext: + runAsUser: 0 + runAsGroup: 0 diff --git a/build/ci/jenkins/pod/docker-pod.yaml b/build/ci/jenkins/pod/docker-pod.yaml new file mode 100644 index 0000000000..1f88a14f38 --- /dev/null +++ b/build/ci/jenkins/pod/docker-pod.yaml @@ -0,0 +1,35 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-images + image: registry.zilliz.com/library/docker:v1.1.0 + imagePullPolicy: Always + securityContext: + privileged: true + command: + - cat + tty: true + resources: + limits: + memory: "8Gi" + cpu: "2" + 
requests: + memory: "2Gi" + cpu: "1" + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock + tolerations: + - key: dedicated + operator: Equal + value: milvus + effect: NoSchedule diff --git a/build/docker/deploy/.env b/build/docker/deploy/.env index 2c2ec0a3df..dd4ba94a07 100644 --- a/build/docker/deploy/.env +++ b/build/docker/deploy/.env @@ -1,3 +1,7 @@ +SOURCE_REPO=milvusdb +TARGET_REPO=milvusdb +SOURCE_TAG=latest +TARGET_TAG=latest PULSAR_ADDRESS=pulsar://pulsar:6650 ETCD_ADDRESS=etcd:2379 MASTER_ADDRESS=master:53100 diff --git a/build/docker/deploy/docker-compose.yml b/build/docker/deploy/docker-compose.yml index 9588637d3f..e8b5044ebc 100644 --- a/build/docker/deploy/docker-compose.yml +++ b/build/docker/deploy/docker-compose.yml @@ -2,27 +2,29 @@ version: '3.5' services: master: - image: master + image: ${TARGET_REPO}/master:${TARGET_TAG} build: context: ../../../ dockerfile: build/docker/deploy/master/DockerFile + cache_from: + - ${SOURCE_REPO}/master:${SOURCE_TAG} environment: PULSAR_ADDRESS: ${PULSAR_ADDRESS} ETCD_ADDRESS: ${ETCD_ADDRESS} - MASTER_ADDRESS: ${MASTER_ADDRESS} networks: - milvus ports: - "53100:53100" proxy: - image: proxy + image: ${TARGET_REPO}/proxy:${TARGET_TAG} build: context: ../../../ dockerfile: build/docker/deploy/proxy/DockerFile + cache_from: + - ${SOURCE_REPO}/proxy:${SOURCE_TAG} environment: PULSAR_ADDRESS: ${PULSAR_ADDRESS} - ETCD_ADDRESS: ${ETCD_ADDRESS} MASTER_ADDRESS: ${MASTER_ADDRESS} ports: - "19530:19530" @@ -30,10 +32,12 @@ services: - milvus querynode: - image: querynode + image: ${TARGET_REPO}/querynode:${TARGET_TAG} build: context: ../../../ dockerfile: build/docker/deploy/querynode/DockerFile + cache_from: + - ${SOURCE_REPO}/querynode:${SOURCE_TAG} environment: PULSAR_ADDRESS: ${PULSAR_ADDRESS} ETCD_ADDRESS: ${ETCD_ADDRESS} diff --git a/internal/allocator/segment.go b/internal/allocator/segment.go index 419033d244..83fc0f05e4 
100644 --- a/internal/allocator/segment.go +++ b/internal/allocator/segment.go @@ -175,7 +175,7 @@ func (sa *SegIDAssigner) syncSegments() { resp, err := sa.masterClient.AssignSegmentID(ctx, req) if resp.Status.GetErrorCode() != commonpb.ErrorCode_SUCCESS { - log.Panic("GRPC AssignSegmentID Failed") + log.Println("GRPC AssignSegmentID Failed", resp, err) return } diff --git a/internal/core/build-support/add_license.sh b/internal/core/build-support/add_license.sh index 5d27816890..83dd6d48f3 100755 --- a/internal/core/build-support/add_license.sh +++ b/internal/core/build-support/add_license.sh @@ -1,11 +1,13 @@ -FOLDER=$1 -if [ -z ${FOLDER} ]; then - echo usage $0 [folder_to_add_license] +LICENSE=$1 +FOLDER=$2 + +if [ -z ${FOLDER} ] || [ -z ${LICENSE} ]; then + echo "usage $0 " exit -else - echo good fi +cat ${LICENSE} > /dev/null || exit -1 + FILES=`find ${FOLDER} \ | grep -E "(*\.cpp$|*\.h$|*\.cu$)" \ | grep -v thirdparty \ @@ -13,13 +15,16 @@ FILES=`find ${FOLDER} \ | grep -v cmake-build \ | grep -v output \ | grep -v "\.pb\."` -echo formating ${FILES} ... +# echo formating ${FILES} ... 
+skip_count=0 for f in ${FILES}; do - if (grep "Apache License" $f);then - echo "No need to copy the License Header to $f" + if (grep "Apache License" $f > /dev/null);then + # echo "No need to copy the License Header to $f" + skip_count=$((skip_count+1)) else - cat cpp_license.txt $f > $f.new + cat ${LICENSE} $f > $f.new mv $f.new $f echo "License Header copied to $f" fi done +echo "license adder: $skip_count file(s) skiped" diff --git a/internal/core/run_clang_format.sh b/internal/core/run_clang_format.sh index e713a542df..b5a682639c 100755 --- a/internal/core/run_clang_format.sh +++ b/internal/core/run_clang_format.sh @@ -13,3 +13,4 @@ formatThis() { formatThis "${CorePath}/src" formatThis "${CorePath}/unittest" +${CorePath}/build-support/add_license.sh ${CorePath}/build-support/cpp_license.txt ${CorePath} diff --git a/internal/core/src/pb/service_msg.pb.cc b/internal/core/src/pb/service_msg.pb.cc index 070902e4da..168a4661e9 100644 --- a/internal/core/src/pb/service_msg.pb.cc +++ b/internal/core/src/pb/service_msg.pb.cc @@ -473,7 +473,7 @@ const char descriptor_table_protodef_service_5fmsg_2eproto[] PROTOBUF_SECTION_VA "g\030\002 \001(\t\"z\n\010RowBatch\022\027\n\017collection_name\030\001" " \001(\t\022\025\n\rpartition_tag\030\002 \001(\t\022+\n\010row_data\030" "\003 \003(\0132\031.milvus.proto.common.Blob\022\021\n\thash" - "_keys\030\004 \003(\005\"d\n\020PlaceholderValue\022\013\n\003tag\030\001" + "_keys\030\004 \003(\r\"d\n\020PlaceholderValue\022\013\n\003tag\030\001" " \001(\t\0223\n\004type\030\002 \001(\0162%.milvus.proto.servic" "e.PlaceholderType\022\016\n\006values\030\003 \003(\014\"P\n\020Pla" "ceholderGroup\022<\n\014placeholders\030\001 \003(\0132&.mi" @@ -1265,10 +1265,10 @@ const char* RowBatch::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::i } while (::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<::PROTOBUF_NAMESPACE_ID::uint8>(ptr) == 26); } else goto handle_unusual; continue; - // repeated int32 hash_keys = 4; + // repeated 
uint32 hash_keys = 4; case 4: if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { - ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(mutable_hash_keys(), ptr, ctx); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt32Parser(mutable_hash_keys(), ptr, ctx); CHK_(ptr); } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32) { add_hash_keys(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint(&ptr)); @@ -1346,15 +1346,15 @@ bool RowBatch::MergePartialFromCodedStream( break; } - // repeated int32 hash_keys = 4; + // repeated uint32 hash_keys = 4; case 4: { if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (34 & 0xFF)) { DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadPackedPrimitive< - ::PROTOBUF_NAMESPACE_ID::int32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_INT32>( + ::PROTOBUF_NAMESPACE_ID::uint32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_UINT32>( input, this->mutable_hash_keys()))); } else if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (32 & 0xFF)) { DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadRepeatedPrimitiveNoInline< - ::PROTOBUF_NAMESPACE_ID::int32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_INT32>( + ::PROTOBUF_NAMESPACE_ID::uint32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_UINT32>( 1, 34u, input, this->mutable_hash_keys()))); } else { goto handle_unusual; @@ -1418,14 +1418,14 @@ void RowBatch::SerializeWithCachedSizes( output); } - // repeated int32 hash_keys = 4; + // repeated uint32 hash_keys = 4; if (this->hash_keys_size() > 0) { ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTag(4, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output); output->WriteVarint32(_hash_keys_cached_byte_size_.load( std::memory_order_relaxed)); } for (int i = 0, n = this->hash_keys_size(); i < n; i++) { - ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32NoTag( + 
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt32NoTag( this->hash_keys(i), output); } @@ -1472,7 +1472,7 @@ void RowBatch::SerializeWithCachedSizes( 3, this->row_data(static_cast(i)), target); } - // repeated int32 hash_keys = 4; + // repeated uint32 hash_keys = 4; if (this->hash_keys_size() > 0) { target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteTagToArray( 4, @@ -1482,7 +1482,7 @@ void RowBatch::SerializeWithCachedSizes( _hash_keys_cached_byte_size_.load(std::memory_order_relaxed), target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: - WriteInt32NoTagToArray(this->hash_keys_, target); + WriteUInt32NoTagToArray(this->hash_keys_, target); } if (_internal_metadata_.have_unknown_fields()) { @@ -1517,10 +1517,10 @@ size_t RowBatch::ByteSizeLong() const { } } - // repeated int32 hash_keys = 4; + // repeated uint32 hash_keys = 4; { size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: - Int32Size(this->hash_keys_); + UInt32Size(this->hash_keys_); if (data_size > 0) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( diff --git a/internal/core/src/pb/service_msg.pb.h b/internal/core/src/pb/service_msg.pb.h index 62d9101204..d113e12ca9 100644 --- a/internal/core/src/pb/service_msg.pb.h +++ b/internal/core/src/pb/service_msg.pb.h @@ -573,15 +573,15 @@ class RowBatch : const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::milvus::proto::common::Blob >& row_data() const; - // repeated int32 hash_keys = 4; + // repeated uint32 hash_keys = 4; int hash_keys_size() const; void clear_hash_keys(); - ::PROTOBUF_NAMESPACE_ID::int32 hash_keys(int index) const; - void set_hash_keys(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); - void add_hash_keys(::PROTOBUF_NAMESPACE_ID::int32 value); - const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + ::PROTOBUF_NAMESPACE_ID::uint32 hash_keys(int index) const; + void set_hash_keys(int index, ::PROTOBUF_NAMESPACE_ID::uint32 
value); + void add_hash_keys(::PROTOBUF_NAMESPACE_ID::uint32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& hash_keys() const; - ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* mutable_hash_keys(); // string collection_name = 1; @@ -612,7 +612,7 @@ class RowBatch : ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::milvus::proto::common::Blob > row_data_; - ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > hash_keys_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 > hash_keys_; mutable std::atomic _hash_keys_cached_byte_size_; ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collection_name_; ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr partition_tag_; @@ -2777,31 +2777,31 @@ RowBatch::row_data() const { return row_data_; } -// repeated int32 hash_keys = 4; +// repeated uint32 hash_keys = 4; inline int RowBatch::hash_keys_size() const { return hash_keys_.size(); } inline void RowBatch::clear_hash_keys() { hash_keys_.Clear(); } -inline ::PROTOBUF_NAMESPACE_ID::int32 RowBatch::hash_keys(int index) const { +inline ::PROTOBUF_NAMESPACE_ID::uint32 RowBatch::hash_keys(int index) const { // @@protoc_insertion_point(field_get:milvus.proto.service.RowBatch.hash_keys) return hash_keys_.Get(index); } -inline void RowBatch::set_hash_keys(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { +inline void RowBatch::set_hash_keys(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value) { hash_keys_.Set(index, value); // @@protoc_insertion_point(field_set:milvus.proto.service.RowBatch.hash_keys) } -inline void RowBatch::add_hash_keys(::PROTOBUF_NAMESPACE_ID::int32 value) { +inline void RowBatch::add_hash_keys(::PROTOBUF_NAMESPACE_ID::uint32 value) { hash_keys_.Add(value); // 
@@protoc_insertion_point(field_add:milvus.proto.service.RowBatch.hash_keys) } -inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& RowBatch::hash_keys() const { // @@protoc_insertion_point(field_list:milvus.proto.service.RowBatch.hash_keys) return hash_keys_; } -inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* RowBatch::mutable_hash_keys() { // @@protoc_insertion_point(field_mutable_list:milvus.proto.service.RowBatch.hash_keys) return &hash_keys_; diff --git a/internal/core/src/query/BruteForceSearch.cpp b/internal/core/src/query/BruteForceSearch.cpp new file mode 100644 index 0000000000..626fe54290 --- /dev/null +++ b/internal/core/src/query/BruteForceSearch.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include "BruteForceSearch.h" +#include + +namespace milvus::query { + +void +BinarySearchBruteForce(faiss::MetricType metric_type, + int64_t code_size, + const uint8_t* binary_chunk, + int64_t chunk_size, + int64_t topk, + int64_t num_queries, + const uint8_t* query_data, + float* result_distances, + idx_t* result_labels, + faiss::ConcurrentBitsetPtr bitset) { + const idx_t block_size = segcore::DefaultElementPerChunk; + bool use_heap = true; + + if (metric_type == faiss::METRIC_Jaccard || metric_type == faiss::METRIC_Tanimoto) { + float* D = result_distances; + for (idx_t query_base_index = 0; query_base_index < num_queries; query_base_index += block_size) { + idx_t query_size = block_size; + if (query_base_index + block_size > num_queries) { + query_size = num_queries - query_base_index; + } + + // We see the distances and labels as heaps. + faiss::float_maxheap_array_t res = {size_t(query_size), size_t(topk), + result_labels + query_base_index * topk, D + query_base_index * topk}; + + binary_distence_knn_hc(metric_type, &res, query_data + query_base_index * code_size, binary_chunk, + chunk_size, code_size, + /* ordered = */ true, bitset); + } + if (metric_type == faiss::METRIC_Tanimoto) { + for (int i = 0; i < topk * num_queries; i++) { + D[i] = -log2(1 - D[i]); + } + } + } else if (metric_type == faiss::METRIC_Substructure || metric_type == faiss::METRIC_Superstructure) { + float* D = result_distances; + for (idx_t s = 0; s < num_queries; s += block_size) { + idx_t nn = block_size; + if (s + block_size > num_queries) { + nn = num_queries - s; + } + + // only match ids will be chosed, not to use heap + binary_distence_knn_mc(metric_type, query_data + s * code_size, binary_chunk, nn, chunk_size, topk, + code_size, D + s * topk, result_labels + s * topk, bitset); + } + } else if (metric_type == faiss::METRIC_Hamming) { + std::vector int_distances(topk * 
num_queries); + for (idx_t s = 0; s < num_queries; s += block_size) { + idx_t nn = block_size; + if (s + block_size > num_queries) { + nn = num_queries - s; + } + if (use_heap) { + // We see the distances and labels as heaps. + faiss::int_maxheap_array_t res = {size_t(nn), size_t(topk), result_labels + s * topk, + int_distances.data() + s * topk}; + + hammings_knn_hc(&res, query_data + s * code_size, binary_chunk, chunk_size, code_size, + /* ordered = */ true, bitset); + } else { + hammings_knn_mc(query_data + s * code_size, binary_chunk, nn, chunk_size, topk, code_size, + int_distances.data() + s * topk, result_labels + s * topk, bitset); + } + } + for (int i = 0; i < num_queries; ++i) { + result_distances[i] = static_cast(int_distances[i]); + } + } +} +} // namespace milvus::query diff --git a/internal/core/src/query/BruteForceSearch.h b/internal/core/src/query/BruteForceSearch.h new file mode 100644 index 0000000000..1edc19e159 --- /dev/null +++ b/internal/core/src/query/BruteForceSearch.h @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#pragma once +#include +#include "segcore/ConcurrentVector.h" +#include "common/Schema.h" + +namespace milvus::query { +void +BinarySearchBruteForce(faiss::MetricType metric_type, + int64_t code_size, + const uint8_t* binary_chunk, + int64_t chunk_size, + int64_t topk, + int64_t num_queries, + const uint8_t* query_data, + float* result_distances, + idx_t* result_labels, + faiss::ConcurrentBitsetPtr bitset = nullptr); +} // namespace milvus::query diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 9cd8c684af..671713dc5b 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -9,6 +9,7 @@ set(MILVUS_QUERY_SRCS visitors/ExecExprVisitor.cpp Plan.cpp Search.cpp + BruteForceSearch.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) target_link_libraries(milvus_query milvus_proto) diff --git a/internal/core/src/query/Search.cpp b/internal/core/src/query/Search.cpp index d18f4dc57f..b73014a86a 100644 --- a/internal/core/src/query/Search.cpp +++ b/internal/core/src/query/Search.cpp @@ -93,8 +93,8 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, } segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids); } - - auto vec_ptr = record.get_vec_entity(vecfield_offset); + using segcore::FloatVector; + auto vec_ptr = record.get_entity(vecfield_offset); // step 4: brute force search where small indexing is unavailable for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) { std::vector buf_uids(total_count, -1); diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index d1b5c3d963..83e5782c91 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -79,7 +79,7 @@ 
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, IndexFunc index_fu Assert(field_offset_opt); auto field_offset = field_offset_opt.value(); auto& field_meta = schema[field_offset]; - auto vec_ptr = records.get_scalar_entity(field_offset); + auto vec_ptr = records.get_entity(field_offset); auto& vec = *vec_ptr; auto& indexing_record = segment_.get_indexing_record(); const segcore::ScalarIndexingEntry& entry = indexing_record.get_scalar_entry(field_offset); diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index 8373b6ee2f..74ece55121 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -109,20 +109,20 @@ class VectorBase { }; template -class ConcurrentVector : public VectorBase { +class ConcurrentVectorImpl : public VectorBase { public: // constants using Chunk = FixedVector; - ConcurrentVector(ConcurrentVector&&) = delete; - ConcurrentVector(const ConcurrentVector&) = delete; + ConcurrentVectorImpl(ConcurrentVectorImpl&&) = delete; + ConcurrentVectorImpl(const ConcurrentVectorImpl&) = delete; - ConcurrentVector& - operator=(ConcurrentVector&&) = delete; - ConcurrentVector& - operator=(const ConcurrentVector&) = delete; + ConcurrentVectorImpl& + operator=(ConcurrentVectorImpl&&) = delete; + ConcurrentVectorImpl& + operator=(const ConcurrentVectorImpl&) = delete; public: - explicit ConcurrentVector(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) { + explicit ConcurrentVectorImpl(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) { Assert(is_scalar ? 
dim == 1 : dim != 1); } @@ -221,4 +221,28 @@ class ConcurrentVector : public VectorBase { ThreadSafeVector chunks_; }; +template +class ConcurrentVector : public ConcurrentVectorImpl { + using ConcurrentVectorImpl::ConcurrentVectorImpl; +}; + +class FloatVector {}; +class BinaryVector {}; + +template <> +class ConcurrentVector : public ConcurrentVectorImpl { + using ConcurrentVectorImpl::ConcurrentVectorImpl; +}; + +template <> +class ConcurrentVector : public ConcurrentVectorImpl { + public: + explicit ConcurrentVector(int64_t dim) : binary_dim_(dim), ConcurrentVectorImpl(dim / 8) { + Assert(dim % 8 == 0); + } + + private: + int64_t binary_dim_; +}; + } // namespace milvus::segcore diff --git a/internal/core/src/segcore/DeletedRecord.h b/internal/core/src/segcore/DeletedRecord.h index 58b8f81207..c88d1edde5 100644 --- a/internal/core/src/segcore/DeletedRecord.h +++ b/internal/core/src/segcore/DeletedRecord.h @@ -55,8 +55,8 @@ struct DeletedRecord { public: std::atomic reserved = 0; AckResponder ack_responder_; - ConcurrentVector timestamps_; - ConcurrentVector uids_; + ConcurrentVector timestamps_; + ConcurrentVector uids_; private: std::shared_ptr lru_; diff --git a/internal/core/src/segcore/IndexingEntry.cpp b/internal/core/src/segcore/IndexingEntry.cpp index 460291a7da..4930bf7f98 100644 --- a/internal/core/src/segcore/IndexingEntry.cpp +++ b/internal/core/src/segcore/IndexingEntry.cpp @@ -22,7 +22,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector assert(field_meta_.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field_meta_.get_dim(); - auto source = dynamic_cast*>(vec_base); + auto source = dynamic_cast*>(vec_base); Assert(source); auto chunk_size = source->chunk_size(); assert(ack_end <= chunk_size); @@ -87,7 +87,7 @@ void ScalarIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) { auto dim = field_meta_.get_dim(); - auto source = dynamic_cast*>(vec_base); + auto source = 
dynamic_cast*>(vec_base); Assert(source); auto chunk_size = source->chunk_size(); assert(ack_end <= chunk_size); @@ -106,7 +106,12 @@ ScalarIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const std::unique_ptr CreateIndex(const FieldMeta& field_meta) { if (field_meta.is_vector()) { - return std::make_unique(field_meta); + if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) { + return std::make_unique(field_meta); + } else { + // TODO + PanicInfo("unsupported"); + } } switch (field_meta.get_data_type()) { case DataType::INT8: diff --git a/internal/core/src/segcore/IndexingEntry.h b/internal/core/src/segcore/IndexingEntry.h index 33b2d96c52..bc7787d8c7 100644 --- a/internal/core/src/segcore/IndexingEntry.h +++ b/internal/core/src/segcore/IndexingEntry.h @@ -100,7 +100,9 @@ class IndexingRecord { Initialize() { int offset = 0; for (auto& field : schema_) { - entries_.try_emplace(offset, CreateIndex(field)); + if (field.get_data_type() != DataType::VECTOR_BINARY) { + entries_.try_emplace(offset, CreateIndex(field)); + } ++offset; } assert(offset == schema_.size()); diff --git a/internal/core/src/segcore/InsertRecord.cpp b/internal/core/src/segcore/InsertRecord.cpp index ac007fb440..5eeb91c0af 100644 --- a/internal/core/src/segcore/InsertRecord.cpp +++ b/internal/core/src/segcore/InsertRecord.cpp @@ -16,36 +16,42 @@ namespace milvus::segcore { InsertRecord::InsertRecord(const Schema& schema) : uids_(1), timestamps_(1) { for (auto& field : schema) { if (field.is_vector()) { - Assert(field.get_data_type() == DataType::VECTOR_FLOAT); - entity_vec_.emplace_back(std::make_shared>(field.get_dim())); - continue; + if (field.get_data_type() == DataType::VECTOR_FLOAT) { + entity_vec_.emplace_back(std::make_shared>(field.get_dim())); + continue; + } else if (field.get_data_type() == DataType::VECTOR_BINARY) { + entity_vec_.emplace_back(std::make_shared>(field.get_dim())); + continue; + } else { + PanicInfo("unsupported"); + } } switch 
(field.get_data_type()) { case DataType::INT8: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } case DataType::INT16: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } case DataType::INT32: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } case DataType::INT64: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } case DataType::FLOAT: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } case DataType::DOUBLE: { - entity_vec_.emplace_back(std::make_shared>()); + entity_vec_.emplace_back(std::make_shared>()); break; } default: { diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index c8115ab428..ab8d5f5284 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -21,22 +21,14 @@ namespace milvus::segcore { struct InsertRecord { std::atomic reserved = 0; AckResponder ack_responder_; - ConcurrentVector timestamps_; - ConcurrentVector uids_; + ConcurrentVector timestamps_; + ConcurrentVector uids_; std::vector> entity_vec_; explicit InsertRecord(const Schema& schema); template auto - get_scalar_entity(int offset) const { - auto ptr = std::dynamic_pointer_cast>(entity_vec_[offset]); - Assert(ptr); - return ptr; - } - - template - auto - get_vec_entity(int offset) const { + get_entity(int offset) const { auto ptr = std::dynamic_pointer_cast>(entity_vec_[offset]); Assert(ptr); return ptr; @@ -44,15 +36,7 @@ struct InsertRecord { template auto - get_scalar_entity(int offset) { - auto ptr = std::dynamic_pointer_cast>(entity_vec_[offset]); - Assert(ptr); - return ptr; - } - - template - auto - get_vec_entity(int offset) { + get_entity(int offset) { auto ptr = 
std::dynamic_pointer_cast>(entity_vec_[offset]); Assert(ptr); return ptr; diff --git a/internal/core/src/segcore/SegmentNaive.cpp b/internal/core/src/segcore/SegmentNaive.cpp index 29edd86bd1..1091c6df25 100644 --- a/internal/core/src/segcore/SegmentNaive.cpp +++ b/internal/core/src/segcore/SegmentNaive.cpp @@ -249,7 +249,8 @@ SegmentNaive::QueryImpl(query::QueryDeprecatedPtr query_info, Timestamp timestam auto the_offset_opt = schema_->get_offset(query_info->field_name); Assert(the_offset_opt.has_value()); Assert(the_offset_opt.value() < record_.entity_vec_.size()); - auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); + auto vec_ptr = + std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); auto index_entry = index_meta_->lookup_by_field(query_info->field_name); auto conf = index_entry.config; @@ -308,7 +309,8 @@ SegmentNaive::QueryBruteForceImpl(query::QueryDeprecatedPtr query_info, Timestam auto the_offset_opt = schema_->get_offset(query_info->field_name); Assert(the_offset_opt.has_value()); Assert(the_offset_opt.value() < record_.entity_vec_.size()); - auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); + auto vec_ptr = + std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); std::vector final_uids(total_count); std::vector final_dis(total_count, std::numeric_limits::max()); @@ -364,7 +366,8 @@ SegmentNaive::QuerySlowImpl(query::QueryDeprecatedPtr query_info, Timestamp time auto the_offset_opt = schema_->get_offset(query_info->field_name); Assert(the_offset_opt.has_value()); Assert(the_offset_opt.value() < record_.entity_vec_.size()); - auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); + auto vec_ptr = + std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); std::vector>> records(num_queries); auto get_L2_distance = [dim](const float* a, const float* b) { @@ -467,7 +470,7 @@ 
SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) { auto chunk_size = record_.uids_.chunk_size(); auto& uids = record_.uids_; - auto entities = record_.get_vec_entity(offset); + auto entities = record_.get_entity(offset); std::vector datasets; for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) { diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index de52f78252..ea87efed04 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -238,7 +238,7 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) { auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode); auto& uids = record_.uids_; - auto entities = record_.get_vec_entity(offset); + auto entities = record_.get_entity(offset); std::vector datasets; for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) { @@ -367,7 +367,7 @@ SegmentSmallIndex::FillTargetEntry(const query::Plan* plan, QueryResult& results auto key_offset = key_offset_opt.value(); auto field_meta = schema_->operator[](key_offset); Assert(field_meta.get_data_type() == DataType::INT64); - auto uids = record_.get_scalar_entity(key_offset); + auto uids = record_.get_entity(key_offset); for (int64_t i = 0; i < size; ++i) { auto seg_offset = results.internal_seg_offsets_[i]; auto row_id = uids->operator[](seg_offset); diff --git a/internal/core/src/segcore/collection_c.h b/internal/core/src/segcore/collection_c.h index e00a9be688..34a5d8a0fa 100644 --- a/internal/core/src/segcore/collection_c.h +++ b/internal/core/src/segcore/collection_c.h @@ -9,10 +9,22 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. 
See the License for the specific language governing permissions and limitations under the License +#pragma once + #ifdef __cplusplus extern "C" { #endif +enum ErrorCode { + Success = 0, + UnexpectedException = 1, +}; + +typedef struct CStatus { + int error_code; + const char* error_msg; +} CStatus; + typedef void* CCollection; CCollection diff --git a/internal/core/src/segcore/plan_c.cpp b/internal/core/src/segcore/plan_c.cpp index 0d0fcf07de..0fe80580cf 100644 --- a/internal/core/src/segcore/plan_c.cpp +++ b/internal/core/src/segcore/plan_c.cpp @@ -13,21 +13,52 @@ #include "query/Plan.h" #include "segcore/Collection.h" -CPlan -CreatePlan(CCollection c_col, const char* dsl) { +CStatus +CreatePlan(CCollection c_col, const char* dsl, CPlan* res_plan) { auto col = (milvus::segcore::Collection*)c_col; - auto res = milvus::query::CreatePlan(*col->get_schema(), dsl); - return (CPlan)res.release(); + try { + auto res = milvus::query::CreatePlan(*col->get_schema(), dsl); + + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + auto plan = (CPlan)res.release(); + *res_plan = plan; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedException; + status.error_msg = strdup(e.what()); + *res_plan = nullptr; + return status; + } } -CPlaceholderGroup -ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, int64_t blob_size) { +CStatus +ParsePlaceholderGroup(CPlan c_plan, + void* placeholder_group_blob, + int64_t blob_size, + CPlaceholderGroup* res_placeholder_group) { std::string blob_string((char*)placeholder_group_blob, (char*)placeholder_group_blob + blob_size); auto plan = (milvus::query::Plan*)c_plan; - auto res = milvus::query::ParsePlaceholderGroup(plan, blob_string); - return (CPlaceholderGroup)res.release(); + try { + auto res = milvus::query::ParsePlaceholderGroup(plan, blob_string); + + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + auto 
group = (CPlaceholderGroup)res.release(); + *res_placeholder_group = group; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedException; + status.error_msg = strdup(e.what()); + *res_placeholder_group = nullptr; + return status; + } } int64_t diff --git a/internal/core/src/segcore/plan_c.h b/internal/core/src/segcore/plan_c.h index b25e4843d4..8f995f22a5 100644 --- a/internal/core/src/segcore/plan_c.h +++ b/internal/core/src/segcore/plan_c.h @@ -20,11 +20,14 @@ extern "C" { typedef void* CPlan; typedef void* CPlaceholderGroup; -CPlan -CreatePlan(CCollection col, const char* dsl); +CStatus +CreatePlan(CCollection col, const char* dsl, CPlan* res_plan); -CPlaceholderGroup -ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, int64_t blob_size); +CStatus +ParsePlaceholderGroup(CPlan plan, + void* placeholder_group_blob, + int64_t blob_size, + CPlaceholderGroup* res_placeholder_group); int64_t GetNumOfQueries(CPlaceholderGroup placeholder_group); diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 7b6caee4f5..b053daddca 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -71,7 +71,7 @@ Insert(CSegmentBase c_segment, status.error_code = Success; status.error_msg = ""; return status; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { auto status = CStatus(); status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); @@ -103,7 +103,7 @@ Delete( status.error_code = Success; status.error_msg = ""; return status; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { auto status = CStatus(); status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); @@ -141,7 +141,7 @@ Search(CSegmentBase c_segment, auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, *query_result); status.error_code = Success; status.error_msg 
= ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index d3583f1295..7f73223347 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -17,22 +17,11 @@ extern "C" { #include #include -#include "segcore/collection_c.h" #include "segcore/plan_c.h" typedef void* CSegmentBase; typedef void* CQueryResult; -enum ErrorCode { - Success = 0, - UnexpectedException = 1, -}; - -typedef struct CStatus { - int error_code; - const char* error_msg; -} CStatus; - CSegmentBase NewSegment(CCollection collection, uint64_t segment_id); diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 4210b43f54..cb61890d2d 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -10,7 +10,9 @@ set(MILVUS_TEST_FILES test_indexing.cpp test_query.cpp test_expr.cpp - test_bitmap.cpp) + test_bitmap.cpp + test_binary.cpp + ) add_executable(all_tests ${MILVUS_TEST_FILES} ) diff --git a/internal/core/unittest/test_binary.cpp b/internal/core/unittest/test_binary.cpp new file mode 100644 index 0000000000..e5adf3a269 --- /dev/null +++ b/internal/core/unittest/test_binary.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include "test_utils/DataGen.h" + +using namespace milvus; +using namespace milvus::query; +using namespace milvus::segcore; + +TEST(Binary, Insert) { + int64_t N = 100000; + int64_t num_queries = 10; + int64_t topK = 5; + auto schema = std::make_shared(); + schema->AddField("vecbin", DataType::VECTOR_BINARY, 128); + schema->AddField("age", DataType::INT64); + auto dataset = DataGen(schema, N, 10); + auto segment = CreateSegment(schema); + segment->PreInsert(N); + segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); + int i = 1 + 1; +} diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 1a5dee4315..d974f4bc46 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -14,7 +14,6 @@ #include #include -#include "segcore/collection_c.h" #include "pb/service_msg.pb.h" #include "segcore/reduce_c.h" @@ -151,8 +150,15 @@ TEST(CApiTest, SearchTest) { } auto blob = raw_group.SerializeAsString(); - auto plan = CreatePlan(collection, dsl_string); - auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length()); + void* plan = nullptr; + + auto status = CreatePlan(collection, dsl_string, &plan); + assert(status.error_code == Success); + + void* placeholderGroup = nullptr; + status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup); + assert(status.error_code == Success); + std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); timestamps.clear(); @@ -611,8 +617,15 @@ TEST(CApiTest, Reduce) { } auto blob = raw_group.SerializeAsString(); - auto plan = CreatePlan(collection, dsl_string); - auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length()); + void* plan = nullptr; + + auto status = CreatePlan(collection, dsl_string, &plan); + assert(status.error_code == 
Success); + + void* placeholderGroup = nullptr; + status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup); + assert(status.error_code == Success); + std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); timestamps.clear(); diff --git a/internal/core/unittest/test_concurrent_vector.cpp b/internal/core/unittest/test_concurrent_vector.cpp index 847697dde6..ce0acfb11c 100644 --- a/internal/core/unittest/test_concurrent_vector.cpp +++ b/internal/core/unittest/test_concurrent_vector.cpp @@ -38,7 +38,7 @@ TEST(ConcurrentVector, TestABI) { TEST(ConcurrentVector, TestSingle) { auto dim = 8; - ConcurrentVector c_vec(dim); + ConcurrentVectorImpl c_vec(dim); std::default_random_engine e(42); int data = 0; auto total_count = 0; @@ -66,7 +66,7 @@ TEST(ConcurrentVector, TestMultithreads) { constexpr int threads = 16; std::vector total_counts(threads); - ConcurrentVector c_vec(dim); + ConcurrentVectorImpl c_vec(dim); std::atomic ack_counter = 0; // std::mutex mutex; diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 6f677f4566..266549d01c 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -32,6 +32,8 @@ #include #include "test_utils/Timer.h" #include "segcore/Reduce.h" +#include "test_utils/DataGen.h" +#include "query/BruteForceSearch.h" using std::cin; using std::cout; @@ -55,7 +57,7 @@ generate_data(int N) { uids.push_back(10 * N + i); timestamps.push_back(0); // append vec - float vec[DIM]; + vector vec(DIM); for (auto& x : vec) { x = distribution(er); } @@ -81,6 +83,7 @@ TEST(Indexing, SmartBruteForce) { auto [raw_data, timestamps, uids] = generate_data(N); auto total_count = DIM * TOPK; auto raw = (const float*)raw_data.data(); + AssertInfo(raw, "wtf"); constexpr int64_t queries = 3; auto heap = faiss::float_maxheap_array_t{}; @@ -231,3 +234,106 @@ TEST(Indexing, IVFFlatNM) { cout << ids[i] << "->" << dis[i] << endl; } } 
+ +TEST(Indexing, DISABLED_BinaryBruteForce) { + int64_t N = 100000; + int64_t num_queries = 10; + int64_t topk = 5; + int64_t dim = 64; + auto result_count = topk * num_queries; + auto schema = std::make_shared(); + schema->AddField("vecbin", DataType::VECTOR_BINARY, dim); + schema->AddField("age", DataType::INT64); + auto dataset = DataGen(schema, N, 10); + vector distances(result_count); + vector ids(result_count); + auto bin_vec = dataset.get_col(0); + auto line_sizeof = schema->operator[](0).get_sizeof(); + auto query_data = 1024 * line_sizeof + bin_vec.data(); + query::BinarySearchBruteForce(faiss::MetricType::METRIC_Jaccard, line_sizeof, bin_vec.data(), N, topk, num_queries, + query_data, distances.data(), ids.data()); + QueryResult qr; + qr.num_queries_ = num_queries; + qr.topK_ = topk; + qr.internal_seg_offsets_ = ids; + qr.result_distances_ = distances; + + auto json = QueryResultToJson(qr); + auto ref = Json::parse(R"( +[ + [ + [ + "1024->0.000000", + "86966->0.395349", + "24843->0.404762", + "13806->0.416667", + "44313->0.421053" + ], + [ + "1025->0.000000", + "14226->0.348837", + "1488->0.365854", + "47337->0.377778", + "20913->0.377778" + ], + [ + "1026->0.000000", + "81882->0.386364", + "9215->0.409091", + "95024->0.409091", + "54987->0.414634" + ], + [ + "1027->0.000000", + "68981->0.394737", + "75528->0.404762", + "68794->0.405405", + "21975->0.425000" + ], + [ + "1028->0.000000", + "90290->0.375000", + "34309->0.394737", + "58559->0.400000", + "33865->0.400000" + ], + [ + "1029->0.000000", + "62722->0.388889", + "89070->0.394737", + "18528->0.414634", + "94971->0.421053" + ], + [ + "1030->0.000000", + "67402->0.333333", + "3988->0.347826", + "86376->0.354167", + "84381->0.361702" + ], + [ + "1031->0.000000", + "81569->0.325581", + "12715->0.347826", + "40332->0.363636", + "21037->0.372093" + ], + [ + "1032->0.000000", + "60536->0.428571", + "93293->0.432432", + "70969->0.435897", + "64048->0.450000" + ], + [ + "1033->0.000000", + 
"99022->0.394737", + "11763->0.405405", + "50073->0.428571", + "97118->0.428571" + ] + ] +] +)"); + ASSERT_EQ(json, ref); +} diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 556cfd7cbb..f25f3ddf62 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -214,17 +214,9 @@ TEST(Query, ExecWithPredicate) { Timestamp time = 1000000; std::vector ph_group_arr = {ph_group.get()}; segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); - std::vector> results; int topk = 5; - for (int q = 0; q < num_queries; ++q) { - std::vector result; - for (int k = 0; k < topk; ++k) { - int index = q * topk + k; - result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" + - std::to_string(qr.result_distances_[index])); - } - results.emplace_back(std::move(result)); - } + + Json json = QueryResultToJson(qr); auto ref = Json::parse(R"([ [ @@ -266,7 +258,6 @@ TEST(Query, ExecWithPredicate) { ] ])"); - Json json{results}; ASSERT_EQ(json, ref); } diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index de6860b503..4ce51c3612 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -87,6 +87,24 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) { insert_cols(data); break; } + case engine::DataType::VECTOR_BINARY: { + auto dim = field.get_dim(); + Assert(dim % 8 == 0); + vector data(dim / 8 * N); + for (auto& x : data) { + x = er(); + } + insert_cols(data); + break; + } + case engine::DataType::INT64: { + vector data(N); + for (auto& x : data) { + x = er(); + } + insert_cols(data); + break; + } case engine::DataType::INT32: { vector data(N); for (auto& x : data) { @@ -142,4 +160,41 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) { return raw_group; } +inline auto +CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42) { + 
assert(dim % 8 == 0); + namespace ser = milvus::proto::service; + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); + std::default_random_engine e(seed); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim / 8; ++d) { + vec.push_back(e()); + } + // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); + value->add_values(vec.data(), vec.size() * sizeof(float)); + } + return raw_group; +} + +inline Json +QueryResultToJson(const QueryResult& qr) { + int64_t num_queries = qr.num_queries_; + int64_t topk = qr.topK_; + std::vector> results; + for (int q = 0; q < num_queries; ++q) { + std::vector result; + for (int k = 0; k < topk; ++k) { + int index = q * topk + k; + result.emplace_back(std::to_string(qr.internal_seg_offsets_[index]) + "->" + + std::to_string(qr.result_distances_[index])); + } + results.emplace_back(std::move(result)); + } + return Json{results}; +}; + } // namespace milvus::segcore diff --git a/internal/master/meta_table.go b/internal/master/meta_table.go index 788e8641ca..75a906e710 100644 --- a/internal/master/meta_table.go +++ b/internal/master/meta_table.go @@ -110,7 +110,7 @@ func (mt *metaTable) saveCollectionMeta(coll *pb.CollectionMeta) error { collBytes := proto.MarshalTextString(coll) mt.collID2Meta[coll.ID] = *coll mt.collName2ID[coll.Schema.Name] = coll.ID - return mt.client.Save("/collection/"+strconv.FormatInt(coll.ID, 10), string(collBytes)) + return mt.client.Save("/collection/"+strconv.FormatInt(coll.ID, 10), collBytes) } // metaTable.ddLock.Lock() before call this function @@ -119,7 +119,7 @@ func (mt *metaTable) saveSegmentMeta(seg *pb.SegmentMeta) error { mt.segID2Meta[seg.SegmentID] = *seg - return mt.client.Save("/segment/"+strconv.FormatInt(seg.SegmentID, 10), string(segBytes)) + return mt.client.Save("/segment/"+strconv.FormatInt(seg.SegmentID, 10), segBytes) 
} // metaTable.ddLock.Lock() before call this function @@ -132,7 +132,7 @@ func (mt *metaTable) saveCollectionAndDeleteSegmentsMeta(coll *pb.CollectionMeta kvs := make(map[string]string) collStrs := proto.MarshalTextString(coll) - kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = string(collStrs) + kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = collStrs for _, segID := range segIDs { _, ok := mt.segID2Meta[segID] @@ -152,14 +152,14 @@ func (mt *metaTable) saveCollectionsAndSegmentsMeta(coll *pb.CollectionMeta, seg kvs := make(map[string]string) collBytes := proto.MarshalTextString(coll) - kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = string(collBytes) + kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = collBytes mt.collID2Meta[coll.ID] = *coll mt.collName2ID[coll.Schema.Name] = coll.ID segBytes := proto.MarshalTextString(seg) - kvs["/segment/"+strconv.FormatInt(seg.SegmentID, 10)] = string(segBytes) + kvs["/segment/"+strconv.FormatInt(seg.SegmentID, 10)] = segBytes mt.segID2Meta[seg.SegmentID] = *seg diff --git a/internal/master/segment_manager_test.go b/internal/master/segment_manager_test.go index 0287ea955c..bc20bf308d 100644 --- a/internal/master/segment_manager_test.go +++ b/internal/master/segment_manager_test.go @@ -186,7 +186,7 @@ func TestSegmentManager_SegmentStats(t *testing.T) { baseMsg := msgstream.BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{1}, + HashValues: []uint32{1}, } msg := msgstream.QueryNodeSegStatsMsg{ QueryNodeSegStats: stats, @@ -358,7 +358,7 @@ func TestSegmentManager_RPC(t *testing.T) { }, }, BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{0}, + HashValues: []uint32{0}, }, }, }, diff --git a/internal/master/time_sync_producer.go b/internal/master/time_sync_producer.go index fc58198858..30d5ec962a 100644 --- a/internal/master/time_sync_producer.go +++ b/internal/master/time_sync_producer.go @@ -60,7 +60,7 @@ func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier baseMsg 
:= ms.BaseMsg{ BeginTimestamp: timetick, EndTimestamp: timetick, - HashValues: []int32{0}, + HashValues: []uint32{0}, } timeTickResult := internalPb.TimeTickMsg{ MsgType: internalPb.MsgType_kTimeTick, diff --git a/internal/master/timesync.go b/internal/master/timesync.go index e4448cab05..9ea8b35a9d 100644 --- a/internal/master/timesync.go +++ b/internal/master/timesync.go @@ -69,7 +69,7 @@ func (ttBarrier *softTimeTickBarrier) Start() error { for _, timetickmsg := range ttmsgs.Msgs { ttmsg := timetickmsg.(*ms.TimeTickMsg) oldT, ok := ttBarrier.peer2LastTt[ttmsg.PeerID] - log.Printf("[softTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerID, ttmsg.Timestamp) + // log.Printf("[softTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerID, ttmsg.Timestamp) if !ok { log.Printf("[softTimeTickBarrier] Warning: peerID %d not exist\n", ttmsg.PeerID) diff --git a/internal/master/timesync_test.go b/internal/master/timesync_test.go index a0cdb2cef5..59fb7b2762 100644 --- a/internal/master/timesync_test.go +++ b/internal/master/timesync_test.go @@ -16,7 +16,7 @@ import ( func getTtMsg(msgType internalPb.MsgType, peerID UniqueID, timeStamp uint64) ms.TsMsg { baseMsg := ms.BaseMsg{ - HashValues: []int32{int32(peerID)}, + HashValues: []uint32{uint32(peerID)}, } timeTickResult := internalPb.TimeTickMsg{ MsgType: internalPb.MsgType_kTimeTick, diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index c36d488ab3..54a4bd93ca 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -148,7 +148,7 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { bucketValues[index] = channelID continue } - bucketValues[index] = hashValue % int32(len(ms.producers)) + bucketValues[index] = int32(hashValue % uint32(len(ms.producers))) } reBucketValues[channelID] = bucketValues } diff --git a/internal/msgstream/msgstream_test.go b/internal/msgstream/msgstream_test.go index ce1b76959b..b35d4aa6b4 100644 --- a/internal/msgstream/msgstream_test.go +++ 
b/internal/msgstream/msgstream_test.go @@ -36,11 +36,11 @@ func repackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { return result, nil } -func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) TsMsg { +func getTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg { baseMsg := BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{hashValue}, + HashValues: []uint32{hashValue}, } switch msgType { case internalPb.MsgType_kInsert: @@ -129,11 +129,11 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) TsMsg { return nil } -func getTimeTickMsg(reqID UniqueID, hashValue int32, time uint64) TsMsg { +func getTimeTickMsg(reqID UniqueID, hashValue uint32, time uint64) TsMsg { baseMsg := BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{hashValue}, + HashValues: []uint32{hashValue}, } timeTickResult := internalPb.TimeTickMsg{ MsgType: internalPb.MsgType_kTimeTick, @@ -369,7 +369,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { baseMsg := BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{1, 3}, + HashValues: []uint32{1, 3}, } insertRequest := internalPb.InsertRequest{ @@ -422,7 +422,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { baseMsg := BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{1, 3}, + HashValues: []uint32{1, 3}, } deleteRequest := internalPb.DeleteRequest{ diff --git a/internal/msgstream/task.go b/internal/msgstream/task.go index 69fe04cd1d..1863ba7c60 100644 --- a/internal/msgstream/task.go +++ b/internal/msgstream/task.go @@ -11,7 +11,7 @@ type TsMsg interface { BeginTs() Timestamp EndTs() Timestamp Type() MsgType - HashKeys() []int32 + HashKeys() []uint32 Marshal(TsMsg) ([]byte, error) Unmarshal([]byte) (TsMsg, error) } @@ -19,7 +19,7 @@ type TsMsg interface { type BaseMsg struct { BeginTimestamp Timestamp EndTimestamp Timestamp - HashValues []int32 + HashValues []uint32 } func (bm *BaseMsg) BeginTs() 
Timestamp { @@ -30,7 +30,7 @@ func (bm *BaseMsg) EndTs() Timestamp { return bm.EndTimestamp } -func (bm *BaseMsg) HashKeys() []int32 { +func (bm *BaseMsg) HashKeys() []uint32 { return bm.HashValues } diff --git a/internal/msgstream/task_test.go b/internal/msgstream/task_test.go index d9903af104..17133b3eb4 100644 --- a/internal/msgstream/task_test.go +++ b/internal/msgstream/task_test.go @@ -87,11 +87,11 @@ func newRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, erro return result, nil } -func getInsertTask(reqID UniqueID, hashValue int32) TsMsg { +func getInsertTask(reqID UniqueID, hashValue uint32) TsMsg { baseMsg := BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{hashValue}, + HashValues: []uint32{hashValue}, } insertRequest := internalPb.InsertRequest{ MsgType: internalPb.MsgType_kInsert, diff --git a/internal/proto/service_msg.proto b/internal/proto/service_msg.proto index b480b7c13b..23c199b0d3 100644 --- a/internal/proto/service_msg.proto +++ b/internal/proto/service_msg.proto @@ -30,7 +30,7 @@ message RowBatch { string collection_name = 1; string partition_tag = 2; repeated common.Blob row_data = 3; - repeated int32 hash_keys = 4; + repeated uint32 hash_keys = 4; } /** diff --git a/internal/proto/servicepb/service_msg.pb.go b/internal/proto/servicepb/service_msg.pb.go index 48b44e7817..ccff111ac1 100644 --- a/internal/proto/servicepb/service_msg.pb.go +++ b/internal/proto/servicepb/service_msg.pb.go @@ -148,7 +148,7 @@ type RowBatch struct { CollectionName string `protobuf:"bytes,1,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` PartitionTag string `protobuf:"bytes,2,opt,name=partition_tag,json=partitionTag,proto3" json:"partition_tag,omitempty"` RowData []*commonpb.Blob `protobuf:"bytes,3,rep,name=row_data,json=rowData,proto3" json:"row_data,omitempty"` - HashKeys []int32 `protobuf:"varint,4,rep,packed,name=hash_keys,json=hashKeys,proto3" json:"hash_keys,omitempty"` + HashKeys 
[]uint32 `protobuf:"varint,4,rep,packed,name=hash_keys,json=hashKeys,proto3" json:"hash_keys,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -200,7 +200,7 @@ func (m *RowBatch) GetRowData() []*commonpb.Blob { return nil } -func (m *RowBatch) GetHashKeys() []int32 { +func (m *RowBatch) GetHashKeys() []uint32 { if m != nil { return m.HashKeys } @@ -881,36 +881,36 @@ var fileDescriptor_b4b40b84dd2f74cb = []byte{ 0x9d, 0x41, 0x7f, 0x1a, 0xd1, 0x98, 0x18, 0x7b, 0xc4, 0x45, 0x2f, 0xa0, 0x18, 0xf1, 0x51, 0xdf, 0x21, 0x92, 0x98, 0x7a, 0x5d, 0x6f, 0x94, 0xb6, 0x56, 0x9b, 0x73, 0x6d, 0xca, 0xba, 0xd3, 0xf1, 0xf9, 0x00, 0xff, 0x1d, 0xf1, 0xd1, 0x1e, 0x91, 0x04, 0xad, 0xc3, 0xa2, 0x47, 0x84, 0xd7, 0xff, - 0x4c, 0xc7, 0xc2, 0xcc, 0xd7, 0xf5, 0xc6, 0x02, 0x2e, 0x26, 0x86, 0x23, 0x3a, 0x16, 0xd6, 0x08, - 0xaa, 0xa7, 0x3e, 0xb1, 0xa9, 0xc7, 0x7d, 0x87, 0x46, 0xe7, 0xc4, 0x8f, 0x27, 0x35, 0x69, 0x93, - 0x9a, 0xd0, 0x0e, 0xe4, 0xe5, 0x38, 0xa4, 0x2a, 0xa9, 0xca, 0xd6, 0x66, 0xf3, 0xb6, 0xd9, 0x34, - 0x67, 0xe2, 0xf4, 0xc6, 0x21, 0xc5, 0xca, 0x05, 0xad, 0x40, 0xe1, 0x32, 0x89, 0x2a, 0x54, 0xc6, - 0x06, 0xce, 0x34, 0xeb, 0xd3, 0x1c, 0xf0, 0xdb, 0x88, 0xc7, 0x21, 0x3a, 0x04, 0x23, 0x9c, 0xda, - 0x84, 0xa9, 0xa9, 0x1a, 0xff, 0xff, 0x2d, 0x9c, 0x4a, 0x1b, 0xcf, 0xf9, 0x5a, 0x5f, 0x34, 0x58, - 0x78, 0x1f, 0xd3, 0x68, 0x7c, 0xf7, 0x19, 0x6c, 0x42, 0x65, 0x6e, 0x06, 0xc2, 0xcc, 0xd5, 0xf5, - 0xc6, 0x22, 0x2e, 0xcf, 0x0e, 0x41, 0x24, 0xed, 0x71, 0x84, 0x6f, 0xea, 0x69, 0x7b, 0x1c, 0xe1, - 0xa3, 0x67, 0xb0, 0x3c, 0x83, 0xdd, 0x77, 0x93, 0x62, 0xcc, 0x7c, 0x5d, 0x6b, 0x18, 0xb8, 0x1a, - 0xde, 0x28, 0xd2, 0xfa, 0x08, 0x95, 0x33, 0x19, 0xb1, 0xc0, 0xc5, 0x54, 0x84, 0x3c, 0x10, 0x14, - 0x6d, 0x43, 0x41, 0x48, 0x22, 0x63, 0xa1, 0xf2, 0x2a, 0x6d, 0xad, 0xdf, 0x3a, 0xd4, 0x33, 0xf5, - 0x05, 0x67, 0x5f, 0x51, 0x0d, 0x16, 0x54, 0x27, 0xb3, 0x45, 0x49, 0x15, 0xeb, 0x02, 0x8c, 0x0e, - 0xe7, 0xfe, 0x23, 0x86, 0x2e, 0x5e, 0x87, 
0x26, 0x80, 0xd2, 0xbc, 0x8f, 0x99, 0x90, 0x0f, 0x03, - 0x98, 0xee, 0x44, 0xda, 0xe0, 0xeb, 0x9d, 0x18, 0xc0, 0x3f, 0x07, 0x81, 0xa4, 0x2e, 0x8d, 0x1e, - 0x1b, 0x43, 0x9f, 0x60, 0x08, 0xa8, 0x65, 0x18, 0x98, 0x04, 0x2e, 0x7d, 0x70, 0xa7, 0x06, 0xd4, - 0x65, 0x81, 0xea, 0x94, 0x8e, 0x53, 0x25, 0x59, 0x10, 0x1a, 0x38, 0x6a, 0x41, 0x74, 0x9c, 0x88, - 0xd6, 0x77, 0x0d, 0xfe, 0x9d, 0x72, 0xd3, 0x1e, 0x15, 0x76, 0xc4, 0xc2, 0x44, 0xbc, 0x1f, 0xec, - 0x2b, 0x28, 0xa4, 0xcc, 0xa7, 0x70, 0x4b, 0x3f, 0x1d, 0x64, 0xca, 0x8a, 0x53, 0xc0, 0x33, 0x65, - 0xc0, 0x99, 0x13, 0x6a, 0x03, 0x24, 0x81, 0x98, 0x90, 0xcc, 0x16, 0x19, 0x91, 0xfc, 0x77, 0x2b, - 0xee, 0x11, 0x1d, 0xab, 0xdb, 0x3a, 0x25, 0x2c, 0xc2, 0x33, 0x4e, 0xd6, 0x37, 0x0d, 0x6a, 0x13, - 0xc6, 0x7c, 0x70, 0x3d, 0x2f, 0x21, 0xaf, 0xce, 0x32, 0xad, 0x66, 0xe3, 0x17, 0xf7, 0x3e, 0x4b, - 0xd0, 0x58, 0x39, 0x3c, 0x46, 0x25, 0x47, 0x90, 0x7f, 0xc7, 0xa4, 0xba, 0xea, 0x83, 0xbd, 0x94, - 0x72, 0x74, 0x9c, 0x88, 0x68, 0x75, 0x86, 0x6d, 0x73, 0x8a, 0xbb, 0x26, 0x94, 0xba, 0x92, 0x0c, - 0x80, 0x47, 0x19, 0xa9, 0xe5, 0x70, 0xa6, 0x59, 0xe7, 0x50, 0x52, 0x9c, 0x83, 0xa9, 0x88, 0x7d, - 0x79, 0xbf, 0x66, 0x20, 0xc8, 0x7b, 0x4c, 0x8a, 0x0c, 0x52, 0xc9, 0x4f, 0x5f, 0xc3, 0xd2, 0x0d, - 0x76, 0x45, 0x45, 0xc8, 0x77, 0x4f, 0xba, 0xfb, 0xd5, 0xbf, 0xd0, 0x32, 0x94, 0xcf, 0xf7, 0x77, - 0x7b, 0x27, 0xb8, 0xdf, 0x39, 0xe8, 0xb6, 0xf1, 0x45, 0xd5, 0x41, 0x55, 0x30, 0x32, 0xd3, 0x9b, - 0xe3, 0x93, 0x76, 0xaf, 0x4a, 0x3b, 0xbb, 0x1f, 0xda, 0x2e, 0x93, 0x5e, 0x3c, 0x48, 0x50, 0x5b, - 0x57, 0xcc, 0xf7, 0xd9, 0x95, 0xa4, 0xb6, 0xd7, 0x4a, 0x33, 0x7a, 0xee, 0x30, 0x21, 0x23, 0x36, - 0x88, 0x25, 0x75, 0x5a, 0x2c, 0x90, 0x34, 0x0a, 0x88, 0xdf, 0x52, 0x69, 0xb6, 0xb2, 0x01, 0x84, - 0x83, 0x41, 0x41, 0x19, 0xb6, 0x7f, 0x04, 0x00, 0x00, 0xff, 0xff, 0x33, 0xc8, 0x08, 0xe2, 0xaf, + 0x4c, 0xc7, 0xc2, 0xcc, 0xd7, 0xf5, 0x46, 0x19, 0x17, 0x13, 0xc3, 0x11, 0x1d, 0x0b, 0x6b, 0x04, + 0xd5, 0x53, 0x9f, 0xd8, 0xd4, 0xe3, 0xbe, 0x43, 0xa3, 0x73, 0xe2, 0xc7, 0x93, 
0x9a, 0xb4, 0x49, + 0x4d, 0x68, 0x07, 0xf2, 0x72, 0x1c, 0x52, 0x95, 0x54, 0x65, 0x6b, 0xb3, 0x79, 0xdb, 0x6c, 0x9a, + 0x33, 0x71, 0x7a, 0xe3, 0x90, 0x62, 0xe5, 0x82, 0x56, 0xa0, 0x70, 0x99, 0x44, 0x15, 0x2a, 0x63, + 0x03, 0x67, 0x9a, 0xf5, 0x69, 0x0e, 0xf8, 0x6d, 0xc4, 0xe3, 0x10, 0x1d, 0x82, 0x11, 0x4e, 0x6d, + 0xc2, 0xd4, 0x54, 0x8d, 0xff, 0xff, 0x16, 0x4e, 0xa5, 0x8d, 0xe7, 0x7c, 0xad, 0x2f, 0x1a, 0x2c, + 0xbc, 0x8f, 0x69, 0x34, 0xbe, 0xfb, 0x0c, 0x36, 0xa1, 0x32, 0x37, 0x03, 0x61, 0xe6, 0xea, 0x7a, + 0x63, 0x11, 0x97, 0x67, 0x87, 0x20, 0x92, 0xf6, 0x38, 0xc2, 0x37, 0xf5, 0xb4, 0x3d, 0x8e, 0xf0, + 0xd1, 0x33, 0x58, 0x9e, 0xc1, 0xee, 0xbb, 0x49, 0x31, 0x66, 0xbe, 0xae, 0x35, 0x0c, 0x5c, 0x0d, + 0x6f, 0x14, 0x69, 0x7d, 0x84, 0xca, 0x99, 0x8c, 0x58, 0xe0, 0x62, 0x2a, 0x42, 0x1e, 0x08, 0x8a, + 0xb6, 0xa1, 0x20, 0x24, 0x91, 0xb1, 0x50, 0x79, 0x95, 0xb6, 0xd6, 0x6f, 0x1d, 0xea, 0x99, 0xfa, + 0x82, 0xb3, 0xaf, 0xa8, 0x06, 0x0b, 0xaa, 0x93, 0xd9, 0xa2, 0xa4, 0x8a, 0x75, 0x01, 0x46, 0x87, + 0x73, 0xff, 0x11, 0x43, 0x17, 0xaf, 0x43, 0x13, 0x40, 0x69, 0xde, 0xc7, 0x4c, 0xc8, 0x87, 0x01, + 0x4c, 0x77, 0x22, 0x6d, 0xf0, 0xf5, 0x4e, 0x0c, 0xe0, 0x9f, 0x83, 0x40, 0x52, 0x97, 0x46, 0x8f, + 0x8d, 0xa1, 0x4f, 0x30, 0x04, 0xd4, 0x32, 0x0c, 0x4c, 0x02, 0x97, 0x3e, 0xb8, 0x53, 0x03, 0xea, + 0xb2, 0x40, 0x75, 0x4a, 0xc7, 0xa9, 0x92, 0x2c, 0x08, 0x0d, 0x1c, 0xb5, 0x20, 0x3a, 0x4e, 0x44, + 0xeb, 0xbb, 0x06, 0xff, 0x4e, 0xb9, 0x69, 0x8f, 0x0a, 0x3b, 0x62, 0x61, 0x22, 0xde, 0x0f, 0xf6, + 0x15, 0x14, 0x52, 0xe6, 0x53, 0xb8, 0xa5, 0x9f, 0x0e, 0x32, 0x65, 0xc5, 0x29, 0xe0, 0x99, 0x32, + 0xe0, 0xcc, 0x09, 0xb5, 0x01, 0x92, 0x40, 0x4c, 0x48, 0x66, 0x8b, 0x8c, 0x48, 0xfe, 0xbb, 0x15, + 0xf7, 0x88, 0x8e, 0xd5, 0x6d, 0x9d, 0x12, 0x16, 0xe1, 0x19, 0x27, 0xeb, 0x9b, 0x06, 0xb5, 0x09, + 0x63, 0x3e, 0xb8, 0x9e, 0x97, 0x90, 0x57, 0x67, 0x99, 0x56, 0xb3, 0xf1, 0x8b, 0x7b, 0x9f, 0x25, + 0x68, 0xac, 0x1c, 0x1e, 0xa3, 0x92, 0x23, 0xc8, 0xbf, 0x63, 0x52, 0x5d, 0xf5, 0xc1, 0x5e, 0x4a, + 0x39, 0x3a, 0x4e, 
0x44, 0xb4, 0x3a, 0xc3, 0xb6, 0x39, 0xc5, 0x5d, 0x13, 0x4a, 0x5d, 0x49, 0x06, + 0xc0, 0xa3, 0x8c, 0xd4, 0x72, 0x38, 0xd3, 0xac, 0x73, 0x28, 0x29, 0xce, 0xc1, 0x54, 0xc4, 0xbe, + 0xbc, 0x5f, 0x33, 0x10, 0xe4, 0x3d, 0x26, 0x45, 0x06, 0xa9, 0xe4, 0xa7, 0xaf, 0x61, 0xe9, 0x06, + 0xbb, 0xa2, 0x22, 0xe4, 0xbb, 0x27, 0xdd, 0xfd, 0xea, 0x5f, 0x68, 0x19, 0xca, 0xe7, 0xfb, 0xbb, + 0xbd, 0x13, 0xdc, 0xef, 0x1c, 0x74, 0xdb, 0xf8, 0xa2, 0xea, 0xa0, 0x2a, 0x18, 0x99, 0xe9, 0xcd, + 0xf1, 0x49, 0xbb, 0x57, 0xa5, 0x9d, 0xdd, 0x0f, 0x6d, 0x97, 0x49, 0x2f, 0x1e, 0x24, 0xa8, 0xad, + 0x2b, 0xe6, 0xfb, 0xec, 0x4a, 0x52, 0xdb, 0x6b, 0xa5, 0x19, 0x3d, 0x77, 0x98, 0x90, 0x11, 0x1b, + 0xc4, 0x92, 0x3a, 0x2d, 0x16, 0x48, 0x1a, 0x05, 0xc4, 0x6f, 0xa9, 0x34, 0x5b, 0xd9, 0x00, 0xc2, + 0xc1, 0xa0, 0xa0, 0x0c, 0xdb, 0x3f, 0x02, 0x00, 0x00, 0xff, 0xff, 0xee, 0x08, 0x5d, 0xa4, 0xaf, 0x07, 0x00, 0x00, } diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 527b5447ee..25c9e1163f 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -11,14 +11,14 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) -type MetaCache interface { +type Cache interface { Hit(collectionName string) bool Get(collectionName string) (*servicepb.CollectionDescription, error) Update(collectionName string) error - //Write(collectionName string, schema *servicepb.CollectionDescription) error + Remove(collectionName string) error } -var globalMetaCache MetaCache +var globalMetaCache Cache type SimpleMetaCache struct { mu sync.RWMutex @@ -30,29 +30,54 @@ type SimpleMetaCache struct { ctx context.Context } -func (smc *SimpleMetaCache) Hit(collectionName string) bool { - smc.mu.RLock() - defer smc.mu.RUnlock() - _, ok := smc.metas[collectionName] +func (metaCache *SimpleMetaCache) Hit(collectionName string) bool { + metaCache.mu.RLock() + defer metaCache.mu.RUnlock() + _, ok := metaCache.metas[collectionName] return ok } -func (smc *SimpleMetaCache) 
Get(collectionName string) (*servicepb.CollectionDescription, error) { - smc.mu.RLock() - defer smc.mu.RUnlock() - schema, ok := smc.metas[collectionName] +func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) { + metaCache.mu.RLock() + defer metaCache.mu.RUnlock() + schema, ok := metaCache.metas[collectionName] if !ok { return nil, errors.New("collection meta miss") } return schema, nil } -func (smc *SimpleMetaCache) Update(collectionName string) error { - reqID, err := smc.reqIDAllocator.AllocOne() +func (metaCache *SimpleMetaCache) Update(collectionName string) error { + reqID, err := metaCache.reqIDAllocator.AllocOne() if err != nil { return err } - ts, err := smc.tsoAllocator.AllocOne() + ts, err := metaCache.tsoAllocator.AllocOne() + if err != nil { + return err + } + hasCollectionReq := &internalpb.HasCollectionRequest{ + MsgType: internalpb.MsgType_kHasCollection, + ReqID: reqID, + Timestamp: ts, + ProxyID: metaCache.proxyID, + CollectionName: &servicepb.CollectionName{ + CollectionName: collectionName, + }, + } + has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq) + if err != nil { + return err + } + if !has.Value { + return errors.New("collection " + collectionName + " not exists") + } + + reqID, err = metaCache.reqIDAllocator.AllocOne() + if err != nil { + return err + } + ts, err = metaCache.tsoAllocator.AllocOne() if err != nil { return err } @@ -60,20 +85,32 @@ func (smc *SimpleMetaCache) Update(collectionName string) error { MsgType: internalpb.MsgType_kDescribeCollection, ReqID: reqID, Timestamp: ts, - ProxyID: smc.proxyID, + ProxyID: metaCache.proxyID, CollectionName: &servicepb.CollectionName{ CollectionName: collectionName, }, } - - resp, err := smc.masterClient.DescribeCollection(smc.ctx, req) + resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req) if err != nil { return err } - smc.mu.Lock() - defer smc.mu.Unlock() - smc.metas[collectionName] = 
resp + metaCache.mu.Lock() + defer metaCache.mu.Unlock() + metaCache.metas[collectionName] = resp + + return nil +} + +func (metaCache *SimpleMetaCache) Remove(collectionName string) error { + metaCache.mu.Lock() + defer metaCache.mu.Unlock() + + _, ok := metaCache.metas[collectionName] + if !ok { + return errors.New("cannot find collection: " + collectionName) + } + delete(metaCache.metas, collectionName) return nil } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 9553cdff48..2319d55482 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -231,7 +231,7 @@ func TestProxy_Insert(t *testing.T) { CollectionName: collectionName, PartitionTag: "haha", RowData: make([]*commonpb.Blob, 0), - HashKeys: make([]int32, 0), + HashKeys: make([]uint32, 0), } wg.Add(1) @@ -281,7 +281,7 @@ func TestProxy_Search(t *testing.T) { for j := 0; j < 4; j++ { searchResultMsg := &msgstream.SearchResultMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{1}, + HashValues: []uint32{1}, }, SearchResult: internalpb.SearchResult{ MsgType: internalpb.MsgType_kSearchResult, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 250c2894fe..f26d4aabaf 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -7,6 +7,8 @@ import ( "math" "strconv" + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" + "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/allocator" @@ -99,13 +101,20 @@ func (it *InsertTask) Execute() error { return err } autoID := description.Schema.AutoID + var rowIDBegin UniqueID + var rowIDEnd UniqueID if autoID || true { + if it.HashValues == nil || len(it.HashValues) == 0 { + it.HashValues = make([]uint32, 0) + } rowNums := len(it.BaseInsertTask.RowData) - rowIDBegin, rowIDEnd, _ := it.rowIDAllocator.Alloc(uint32(rowNums)) + rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums)) it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums) for i 
:= rowIDBegin; i < rowIDEnd; i++ { offset := i - rowIDBegin it.BaseInsertTask.RowIDs[offset] = i + hashValue, _ := typeutil.Hash32Int64(i) + it.HashValues = append(it.HashValues, hashValue) } } @@ -121,6 +130,8 @@ func (it *InsertTask) Execute() error { Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_SUCCESS, }, + Begin: rowIDBegin, + End: rowIDEnd, } if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR @@ -287,7 +298,7 @@ func (dct *DropCollectionTask) Execute() error { } func (dct *DropCollectionTask) PostExecute() error { - return nil + return globalMetaCache.Remove(dct.CollectionName.CollectionName) } type QueryTask struct { @@ -325,6 +336,18 @@ func (qt *QueryTask) SetTs(ts Timestamp) { } func (qt *QueryTask) PreExecute() error { + collectionName := qt.query.CollectionName + if !globalMetaCache.Hit(collectionName) { + err := globalMetaCache.Update(collectionName) + if err != nil { + return err + } + } + _, err := globalMetaCache.Get(collectionName) + if err != nil { // err is not nil if collection not exists + return err + } + if err := ValidateCollectionName(qt.query.CollectionName); err != nil { return err } @@ -352,7 +375,7 @@ func (qt *QueryTask) Execute() error { var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{ SearchRequest: qt.SearchRequest, BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{int32(Params.ProxyID())}, + HashValues: []uint32{uint32(Params.ProxyID())}, BeginTimestamp: qt.Timestamp, EndTimestamp: qt.Timestamp, }, @@ -378,25 +401,33 @@ func (qt *QueryTask) PostExecute() error { log.Print("wait to finish failed, timeout!") return errors.New("wait to finish failed, timeout") case searchResults := <-qt.resultBuf: - rlen := len(searchResults) // query num + filterSearchResult := make([]*internalpb.SearchResult, 0) + for _, partialSearchResult := range searchResults { + if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS { + filterSearchResult = append(filterSearchResult, partialSearchResult) 
+ } + } + + rlen := len(filterSearchResult) // query num if rlen <= 0 { qt.result = &servicepb.QueryResult{} return nil } - n := len(searchResults[0].Hits) // n + n := len(filterSearchResult[0].Hits) // n if n <= 0 { qt.result = &servicepb.QueryResult{} return nil } hits := make([][]*servicepb.Hits, rlen) - for i, searchResult := range searchResults { + for i, partialSearchResult := range filterSearchResult { hits[i] = make([]*servicepb.Hits, n) - for j, bs := range searchResult.Hits { + for j, bs := range partialSearchResult.Hits { hits[i][j] = &servicepb.Hits{} err := proto.Unmarshal(bs, hits[i][j]) if err != nil { + log.Println("unmarshal error") return err } } @@ -428,6 +459,17 @@ func (qt *QueryTask) PostExecute() error { } } choiceOffset := locs[choice] + // check if distance is valid, `invalid` here means very very big, + // in this process, distance here is the smallest, so the rest of distance are all invalid + if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) { + qt.result = &servicepb.QueryResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "topk in dsl greater than the row nums of collection", + }, + } + return nil + } reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset]) if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 { reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset]) @@ -437,6 +479,7 @@ func (qt *QueryTask) PostExecute() error { } reducedHitsBs, err := proto.Marshal(reducedHits) if err != nil { + log.Println("marshal error") return err } qt.result.Hits = append(qt.result.Hits, reducedHitsBs) diff --git a/internal/proxy/timetick.go b/internal/proxy/timetick.go index 6a4436dd1e..ffbfb50771 100644 --- a/internal/proxy/timetick.go +++ b/internal/proxy/timetick.go @@ -70,7 +70,7 @@ func (tt *timeTick) tick() error { msgPack := msgstream.MsgPack{} timeTickMsg := &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ 
- HashValues: []int32{int32(Params.ProxyID())}, + HashValues: []uint32{uint32(Params.ProxyID())}, }, TimeTickMsg: internalpb.TimeTickMsg{ MsgType: internalpb.MsgType_kTimeTick, diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index e833c5e08a..4b1e8f147b 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -129,8 +129,8 @@ func TestDataSyncService_Start(t *testing.T) { for i := 0; i < msgLength; i++ { var msg msgstream.TsMsg = &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{ - int32(i), int32(i), + HashValues: []uint32{ + uint32(i), uint32(i), }, }, InsertRequest: internalPb.InsertRequest{ @@ -163,7 +163,7 @@ func TestDataSyncService_Start(t *testing.T) { baseMsg := msgstream.BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{0}, + HashValues: []uint32{0}, } timeTickResult := internalPb.TimeTickMsg{ MsgType: internalPb.MsgType_kTimeTick, diff --git a/internal/querynode/plan.go b/internal/querynode/plan.go index 90feb0029b..5673f95cea 100644 --- a/internal/querynode/plan.go +++ b/internal/querynode/plan.go @@ -10,6 +10,8 @@ package querynode */ import "C" import ( + "errors" + "strconv" "unsafe" ) @@ -17,11 +19,21 @@ type Plan struct { cPlan C.CPlan } -func createPlan(col Collection, dsl string) *Plan { +func createPlan(col Collection, dsl string) (*Plan, error) { cDsl := C.CString(dsl) - cPlan := C.CreatePlan(col.collectionPtr, cDsl) + var cPlan C.CPlan + status := C.CreatePlan(col.collectionPtr, cDsl, &cPlan) + + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + var newPlan = &Plan{cPlan: cPlan} - return newPlan + return newPlan, nil } func (plan *Plan) getTopK() int64 
{ @@ -37,12 +49,22 @@ type PlaceholderGroup struct { cPlaceholderGroup C.CPlaceholderGroup } -func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) *PlaceholderGroup { +func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGroup, error) { var blobPtr = unsafe.Pointer(&placeHolderBlob[0]) blobSize := C.long(len(placeHolderBlob)) - cPlaceholderGroup := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize) + var cPlaceholderGroup C.CPlaceholderGroup + status := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize, &cPlaceholderGroup) + + errorCode := status.error_code + + if errorCode != 0 { + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + } + var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup} - return newPlaceholderGroup + return newPlaceholderGroup, nil } func (pg *PlaceholderGroup) getNumOfQuery() int64 { diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 714f7edbe2..0d26f90b9d 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -64,7 +64,8 @@ func TestPlan_Plan(t *testing.T) { dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" - plan := createPlan(*collection, dslString) + plan, err := createPlan(*collection, dslString) + assert.NoError(t, err) assert.NotEqual(t, plan, nil) topk := plan.getTopK() assert.Equal(t, int(topk), 10) @@ -122,7 +123,8 @@ func TestPlan_PlaceholderGroup(t *testing.T) { dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" - plan := createPlan(*collection, dslString) + plan, err := 
createPlan(*collection, dslString) + assert.NoError(t, err) assert.NotNil(t, plan) var searchRawData1 []byte @@ -151,7 +153,8 @@ func TestPlan_PlaceholderGroup(t *testing.T) { placeGroupByte, err := proto.Marshal(&placeholderGroup) assert.Nil(t, err) - holder := parserPlaceholderGroup(plan, placeGroupByte) + holder, err := parserPlaceholderGroup(plan, placeGroupByte) + assert.NoError(t, err) assert.NotNil(t, holder) numQueries := holder.getNumOfQuery() assert.Equal(t, int(numQueries), 2) diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index 241a3001bb..5774c38acc 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -99,8 +99,10 @@ func TestReduce_AllFunc(t *testing.T) { log.Print("marshal placeholderGroup failed") } - plan := createPlan(*collection, dslString) - holder := parserPlaceholderGroup(plan, placeGroupByte) + plan, err := createPlan(*collection, dslString) + assert.NoError(t, err) + holder, err := parserPlaceholderGroup(plan, placeGroupByte) + assert.NoError(t, err) placeholderGroups := make([]*PlaceholderGroup, 0) placeholderGroups = append(placeholderGroups, holder) diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index 9ecf3d78c2..a2a66c329d 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -139,10 +139,10 @@ func (ss *searchService) receiveSearchMsg() { err := ss.search(msg) if err != nil { log.Println("search Failed, error msg type: ", msg.Type()) - } - err = ss.publishFailedSearchResult(msg) - if err != nil { - log.Println("publish FailedSearchResult failed, error message: ", err) + err = ss.publishFailedSearchResult(msg) + if err != nil { + log.Println("publish FailedSearchResult failed, error message: ", err) + } } } log.Println("ReceiveSearchMsg, do search done, num of searchMsg = ", len(searchMsg)) @@ -191,10 +191,10 @@ func (ss *searchService) doUnsolvedMsgSearch() { err := ss.search(msg) 
if err != nil { log.Println("search Failed, error msg type: ", msg.Type()) - } - err = ss.publishFailedSearchResult(msg) - if err != nil { - log.Println("publish FailedSearchResult failed, error message: ", err) + err = ss.publishFailedSearchResult(msg) + if err != nil { + log.Println("publish FailedSearchResult failed, error message: ", err) + } } } log.Println("doUnsolvedMsgSearch, do search done, num of searchMsg = ", len(searchMsg)) @@ -225,9 +225,15 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } collectionID := collection.ID() dsl := query.Dsl - plan := createPlan(*collection, dsl) + plan, err := createPlan(*collection, dsl) + if err != nil { + return err + } placeHolderGroupBlob := query.PlaceholderGroup - placeholderGroup := parserPlaceholderGroup(plan, placeHolderGroupBlob) + placeholderGroup, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) + if err != nil { + return err + } placeholderGroups := make([]*PlaceholderGroup, 0) placeholderGroups = append(placeholderGroups, placeholderGroup) @@ -251,12 +257,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } if len(searchResults) <= 0 { - log.Println("search Failed, invalid partitionTag") - err = ss.publishFailedSearchResult(msg) - if err != nil { - log.Println("publish FailedSearchResult failed, error message: ", err) - } - return err + return errors.New("search Failed, invalid partitionTag") } reducedSearchResult := reduceSearchResults(searchResults, int64(len(searchResults))) @@ -296,7 +297,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { Hits: hits, } searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []int32{0}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, SearchResult: results, } err = ss.publishSearchResult(searchResultMsg) @@ -341,7 +342,7 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error { } tsMsg := &msgstream.SearchResultMsg{ - BaseMsg: 
msgstream.BaseMsg{HashValues: []int32{0}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, SearchResult: results, } msgPack.Msgs = append(msgPack.Msgs, tsMsg) diff --git a/internal/querynode/search_service_test.go b/internal/querynode/search_service_test.go index 422cef945e..d095db3a7f 100644 --- a/internal/querynode/search_service_test.go +++ b/internal/querynode/search_service_test.go @@ -143,7 +143,7 @@ func TestSearch_Search(t *testing.T) { searchMsg := &msgstream.SearchMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{0}, + HashValues: []uint32{0}, }, SearchRequest: internalpb.SearchRequest{ MsgType: internalpb.MsgType_kSearch, @@ -188,8 +188,8 @@ func TestSearch_Search(t *testing.T) { var msg msgstream.TsMsg = &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{ - int32(i), + HashValues: []uint32{ + uint32(i), }, }, InsertRequest: internalpb.InsertRequest{ @@ -221,7 +221,7 @@ func TestSearch_Search(t *testing.T) { baseMsg := msgstream.BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, - HashValues: []int32{0}, + HashValues: []uint32{0}, } timeTickResult := internalpb.TimeTickMsg{ MsgType: internalpb.MsgType_kTimeTick, diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index a9100ca21b..4522289cac 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -690,8 +690,10 @@ func TestSegment_segmentSearch(t *testing.T) { } searchTimestamp := Timestamp(1020) - plan := createPlan(*collection, dslString) - holder := parserPlaceholderGroup(plan, placeHolderGroupBlob) + plan, err := createPlan(*collection, dslString) + assert.NoError(t, err) + holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) + assert.NoError(t, err) placeholderGroups := make([]*PlaceholderGroup, 0) placeholderGroups = append(placeholderGroups, holder) diff --git a/internal/querynode/stats_service.go b/internal/querynode/stats_service.go index 81ebfe15e4..0ded253326 100644 --- 
a/internal/querynode/stats_service.go +++ b/internal/querynode/stats_service.go @@ -74,7 +74,7 @@ func (sService *statsService) sendSegmentStatistic() { func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSegStats) { var msg msgstream.TsMsg = &msgstream.QueryNodeSegStatsMsg{ BaseMsg: msgstream.BaseMsg{ - HashValues: []int32{0}, + HashValues: []uint32{0}, }, QueryNodeSegStats: *statistic, } diff --git a/internal/util/paramtable/paramtable.go b/internal/util/paramtable/paramtable.go index 04b24bbf84..b8334c27f3 100644 --- a/internal/util/paramtable/paramtable.go +++ b/internal/util/paramtable/paramtable.go @@ -82,8 +82,17 @@ func (gp *BaseTable) LoadRange(key, endKey string, limit int) ([]string, []strin func (gp *BaseTable) LoadYaml(fileName string) error { config := viper.New() _, fpath, _, _ := runtime.Caller(0) - configPath := path.Dir(fpath) + "/../../../configs/" - config.SetConfigFile(configPath + fileName) + configFile := path.Dir(fpath) + "/../../../configs/" + fileName + _, err := os.Stat(configFile) + if os.IsNotExist(err) { + runPath, err := os.Getwd() + if err != nil { + panic(err) + } + configFile = runPath + "/configs/" + fileName + } + + config.SetConfigFile(configFile) if err := config.ReadInConfig(); err != nil { panic(err) } diff --git a/scripts/before-install.sh b/scripts/before-install.sh index e767de26c9..64f04f9458 100755 --- a/scripts/before-install.sh +++ b/scripts/before-install.sh @@ -2,14 +2,12 @@ set -ex -export CCACHE_COMPRESS=1 -export CCACHE_COMPRESSLEVEL=5 -export CCACHE_COMPILERCHECK=content -export PATH=/usr/lib/ccache/:$PATH -export CCACHE_BASEDIR=${WORKSPACE:=""} +export CCACHE_COMPRESS=${CCACHE_COMPRESS:="1"} +export CCACHE_COMPRESSLEVEL=${CCACHE_COMPRESSLEVEL:="5"} +export CCACHE_COMPILERCHECK=${CCACHE_COMPILERCHECK:="content"} +export CCACHE_MAXSIZE=${CCACHE_MAXSIZE:="2G"} export CCACHE_DIR=${CCACHE_DIR:="${HOME}/.ccache"} -export 
CCACHE_COMPRESS_PACKAGE_FILE=${CCACHE_COMPRESS_PACKAGE_FILE:="ccache-${OS_NAME}-${BUILD_ENV_IMAGE_ID}.tar.gz"} -export CUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_DOWNLOAD_PATH:="${HOME}/3rdparty_download"} -export THIRDPARTY_COMPRESS_PACKAGE_FILE=${THIRDPARTY_COMPRESS_PACKAGE_FILE:="thirdparty-download.tar.gz"} +export http_proxy="http://proxy.zilliz.tech:1088" +export https_proxy="http://proxy.zilliz.tech:1088" set +ex diff --git a/scripts/check_cache.sh b/scripts/check_cache.sh new file mode 100755 index 0000000000..d675181745 --- /dev/null +++ b/scripts/check_cache.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +HELP=" +Usage: + $0 [flags] [Arguments] + + -l [ARTIFACTORY_URL] Artifactory URL + --cache_dir=[CACHE_DIR] Cache directory + -f [FILE] or --file=[FILE] Cache compress package file + -h or --help Print help information + + +Use \"$0 --help\" for more information about a given command. +" + +ARGS=$(getopt -o "l:f:h" -l "cache_dir::,file::,help" -n "$0" -- "$@") + +eval set -- "${ARGS}" + +while true ; do + case "$1" in + -l) + # o has an optional argument. As we are in quoted mode, + # an empty parameter will be generated if its optional + # argument is not found. + case "$2" in + "") echo "Option Artifactory URL, no argument"; exit 1 ;; + *) ARTIFACTORY_URL=$2 ; shift 2 ;; + esac ;; + --cache_dir) + case "$2" in + "") echo "Option cache_dir, no argument"; exit 1 ;; + *) CACHE_DIR=$2 ; shift 2 ;; + esac ;; + -f|--file) + case "$2" in + "") echo "Option file, no argument"; exit 1 ;; + *) PACKAGE_FILE=$2 ; shift 2 ;; + esac ;; + -h|--help) echo -e "${HELP}" ; exit 0 ;; + --) shift ; break ;; + *) echo "Internal error!" ; exit 1 ;; + esac +done + +# Set defaults for vars modified by flags to this script +BRANCH_NAMES=$(git log --decorate | head -n 1 | sed 's/.*(\(.*\))/\1/' | sed 's=[a-zA-Z]*\/==g' | awk -F", " '{$1=""; print $0}') + +if [[ -z "${ARTIFACTORY_URL}" || "${ARTIFACTORY_URL}" == "" ]];then + echo "You have not input ARTIFACTORY_URL !" 
+ exit 1 +fi + +if [[ -z "${CACHE_DIR}" ]]; then + echo "You have not input CACHE_DIR !" + exit 1 +fi + +if [[ -z "${PACKAGE_FILE}" ]]; then + echo "You have not input PACKAGE_FILE !" + exit 1 +fi + +function check_cache() { + BRANCH=$1 + echo "fetching ${BRANCH}/${PACKAGE_FILE}" + wget -q --spider "${ARTIFACTORY_URL}/${BRANCH}/${PACKAGE_FILE}" + return $? +} + +function download_file() { + BRANCH=$1 + wget -q "${ARTIFACTORY_URL}/${BRANCH}/${PACKAGE_FILE}" && \ + mkdir -p "${CACHE_DIR}" && \ + tar zxf "${PACKAGE_FILE}" -C "${CACHE_DIR}" && \ + rm ${PACKAGE_FILE} + return $? +} + +if [[ -n "${CHANGE_TARGET}" && "${BRANCH_NAME}" =~ "PR-" ]];then + check_cache ${CHANGE_TARGET} + if [[ $? == 0 ]];then + download_file ${CHANGE_TARGET} + if [[ $? == 0 ]];then + echo "found cache" + exit 0 + fi + fi + + check_cache ${BRANCH_NAME} + if [[ $? == 0 ]];then + download_file ${BRANCH_NAME} + if [[ $? == 0 ]];then + echo "found cache" + exit 0 + fi + fi +fi + +for CURRENT_BRANCH in ${BRANCH_NAMES} +do + if [[ "${CURRENT_BRANCH}" != "HEAD" ]];then + check_cache ${CURRENT_BRANCH} + if [[ $? == 0 ]];then + download_file ${CURRENT_BRANCH} + if [[ $? == 0 ]];then + echo "found cache" + exit 0 + fi + fi + fi +done + +echo "could not download cache" && exit 1 diff --git a/scripts/update_cache.sh b/scripts/update_cache.sh new file mode 100755 index 0000000000..85985b741d --- /dev/null +++ b/scripts/update_cache.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +HELP=" +Usage: + $0 [flags] [Arguments] + + -l [ARTIFACTORY_URL] Artifactory URL + --cache_dir=[CACHE_DIR] Cache directory + -f [FILE] or --file=[FILE] Cache compress package file + -u [USERNAME] Artifactory Username + -p [PASSWORD] Artifactory Password + -h or --help Print help information + + +Use \"$0 --help\" for more information about a given command. +" + +ARGS=$(getopt -o "l:f:u:p:h" -l "cache_dir::,file::,help" -n "$0" -- "$@") + +eval set -- "${ARGS}" + +while true ; do + case "$1" in + -l) + # o has an optional argument. 
As we are in quoted mode, + # an empty parameter will be generated if its optional + # argument is not found. + case "$2" in + "") echo "Option Artifactory URL, no argument"; exit 1 ;; + *) ARTIFACTORY_URL=$2 ; shift 2 ;; + esac ;; + --cache_dir) + case "$2" in + "") echo "Option cache_dir, no argument"; exit 1 ;; + *) CACHE_DIR=$2 ; shift 2 ;; + esac ;; + -u) + case "$2" in + "") echo "Option Username, no argument"; exit 1 ;; + *) USERNAME=$2 ; shift 2 ;; + esac ;; + -p) + case "$2" in + "") echo "Option Password, no argument"; exit 1 ;; + *) PASSWORD=$2 ; shift 2 ;; + esac ;; + -f|--file) + case "$2" in + "") echo "Option file, no argument"; exit 1 ;; + *) PACKAGE_FILE=$2 ; shift 2 ;; + esac ;; + -h|--help) echo -e "${HELP}" ; exit 0 ;; + --) shift ; break ;; + *) echo "Internal error!" ; exit 1 ;; + esac +done + +# Set defaults for vars modified by flags to this script +BRANCH_NAME=$(git log --decorate | head -n 1 | sed 's/.*(\(.*\))/\1/' | sed 's/.*, //' | sed 's=[a-zA-Z]*\/==g') + +if [[ -z "${ARTIFACTORY_URL}" || "${ARTIFACTORY_URL}" == "" ]];then + echo "You have not input ARTIFACTORY_URL !" + exit 1 +fi + +if [[ ! -d "${CACHE_DIR}" ]]; then + echo "\"${CACHE_DIR}\" directory does not exist !" + exit 1 +fi + +if [[ -z "${PACKAGE_FILE}" ]]; then + echo "You have not input PACKAGE_FILE !" + exit 1 +fi + +function check_cache() { + BRANCH=$1 + wget -q --spider "${ARTIFACTORY_URL}/${BRANCH}/${PACKAGE_FILE}" + return $? +} + +if [[ -n "${CHANGE_TARGET}" && "${BRANCH_NAME}" =~ "PR-" ]]; then + check_cache ${CHANGE_TARGET} + if [[ $? == 0 ]];then + echo "Skip Update cache package ..." && exit 0 + fi +fi + +if [[ "${BRANCH_NAME}" != "HEAD" ]];then + REMOTE_PACKAGE_PATH="${ARTIFACTORY_URL}/${BRANCH_NAME}" + echo "Updating cache package file: ${PACKAGE_FILE}" + tar zcf ./"${PACKAGE_FILE}" -C "${CACHE_DIR}" . 
+ echo "Uploading cache package file ${PACKAGE_FILE} to ${REMOTE_PACKAGE_PATH}" + curl -u"${USERNAME}":"${PASSWORD}" -T "${PACKAGE_FILE}" "${REMOTE_PACKAGE_PATH}"/"${PACKAGE_FILE}" + if [[ $? == 0 ]];then + echo "Uploading cache package file success !" + exit 0 + else + echo "Uploading cache package file fault !" + exit 1 + fi +fi + +echo "Skip Update cache package ..."