diff --git a/CHANGELOG.md b/CHANGELOG.md index bc13527646..71e4efbb93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ Please mark all change in change log and use the issue from GitHub ## Bug - \#1276 SQLite throw exception after create 50000+ partitions in a table - \#1762 Server is not forbidden to create new partition which tag is `_default` +- \#1789 Fix multi-client search cause server crash +- \#1832 Fix crash in tracing module - \#1873 Fix index file serialize to incorrect path - \#1881 Fix Annoy index search failure diff --git a/core/cmake/ThirdPartyPackages.cmake b/core/cmake/ThirdPartyPackages.cmake index 25d7afa66f..29b4c6e4b4 100644 --- a/core/cmake/ThirdPartyPackages.cmake +++ b/core/cmake/ThirdPartyPackages.cmake @@ -301,6 +301,7 @@ if (DEFINED ENV{MILVUS_GRPC_URL}) set(GRPC_SOURCE_URL "$ENV{MILVUS_GRPC_URL}") else () set(GRPC_SOURCE_URL + "https://github.com/milvus-io/grpc-milvus/archive/${GRPC_VERSION}.zip" "https://github.com/youny626/grpc-milvus/archive/${GRPC_VERSION}.zip" "https://gitee.com/quicksilver/grpc-milvus/repository/archive/${GRPC_VERSION}.zip") endif () diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 8d583b5240..78afd03415 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -158,6 +159,50 @@ ConstructCollectionInfo(const CollectionInfo& collection_info, ::milvus::grpc::C } // namespace +namespace { + +#define REQ_ID ("request_id") + +std::atomic _sequential_id; + +int64_t +get_sequential_id() { + return _sequential_id++; +} + +void +set_request_id(::grpc::ServerContext* context, const std::string& request_id) { + if (not context) { + // error + SERVER_LOG_ERROR << "set_request_id: grpc::ServerContext is nullptr" << std::endl; + return; + } + + context->AddInitialMetadata(REQ_ID, request_id); +} + +std::string +get_request_id(::grpc::ServerContext* context) { + if (not context) { + // error + SERVER_LOG_ERROR << "get_request_id: grpc::ServerContext is nullptr" << std::endl; + return "INVALID_ID"; + } + + auto server_metadata = context->server_metadata(); + + auto request_id_kv = server_metadata.find(REQ_ID); + if (request_id_kv == server_metadata.end()) { + // error + SERVER_LOG_ERROR << std::string(REQ_ID) << " not found in grpc.server_metadata" << std::endl; + return "INVALID_ID"; + } + + return request_id_kv->second.data(); +} + +} // namespace + GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr& tracer) : tracer_(tracer), random_num_generator_() { std::random_device random_device; @@ -187,16 +232,42 @@ GrpcRequestHandler::OnPostRecvInitialMetaData( return; } auto span = tracer_->StartSpan(server_rpc_info->method(), {opentracing::ChildOf(span_context_maybe->get())}); + auto server_context = server_rpc_info->server_context(); auto client_metadata = server_context->client_metadata(); - // TODO: request id + + // if client provide request_id in metadata, milvus just use it, + // else milvus generate a sequential id. std::string request_id; auto request_id_kv = client_metadata.find("request_id"); if (request_id_kv != client_metadata.end()) { request_id = request_id_kv->second.data(); + SERVER_LOG_DEBUG << "client provide request_id: " << request_id; + + // if request_id is being used by another request, + // convert it to request_id_n. + std::lock_guard lock(context_map_mutex_); + if (context_map_.find(request_id) == context_map_.end()) { + // if not found exist, mark + context_map_[request_id] = nullptr; + } else { + // Finding a unused suffix + int64_t suffix = 1; + std::string try_request_id; + bool exist = true; + do { + try_request_id = request_id + "_" + std::to_string(suffix); + exist = context_map_.find(try_request_id) != context_map_.end(); + suffix++; + } while (exist); + context_map_[try_request_id] = nullptr; + } } else { - request_id = std::to_string(random_id()) + std::to_string(random_id()); + request_id = std::to_string(get_sequential_id()); + set_request_id(server_context, request_id); + SERVER_LOG_DEBUG << "milvus generate request_id: " << request_id; } + auto trace_context = std::make_shared(span); auto context = std::make_shared(request_id); context->SetTraceContext(trace_context); @@ -207,23 +278,33 @@ void GrpcRequestHandler::OnPreSendMessage(::grpc::experimental::ServerRpcInfo* server_rpc_info, ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) { std::lock_guard lock(context_map_mutex_); - context_map_[server_rpc_info->server_context()]->GetTraceContext()->GetSpan()->Finish(); - auto search = context_map_.find(server_rpc_info->server_context()); - if (search != context_map_.end()) { - context_map_.erase(search); + auto request_id = get_request_id(server_rpc_info->server_context()); + + if (context_map_.find(request_id) == context_map_.end()) { + // error + SERVER_LOG_ERROR << "request_id " << request_id << " not found in context_map_"; + return; } + context_map_[request_id]->GetTraceContext()->GetSpan()->Finish(); + context_map_.erase(request_id); } const std::shared_ptr& GrpcRequestHandler::GetContext(::grpc::ServerContext* server_context) { std::lock_guard lock(context_map_mutex_); - return context_map_[server_context]; + auto request_id = get_request_id(server_context); + if (context_map_.find(request_id) == context_map_.end()) { + SERVER_LOG_ERROR << "GetContext: request_id " << request_id << " not found in context_map_"; + return nullptr; + } + return context_map_[request_id]; } void GrpcRequestHandler::SetContext(::grpc::ServerContext* server_context, const std::shared_ptr& context) { std::lock_guard lock(context_map_mutex_); - context_map_[server_context] = context; + auto request_id = get_request_id(server_context); + context_map_[request_id] = context; } uint64_t @@ -244,7 +325,7 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil CHECK_NULLPTR_RETURN(request); Status status = - request_handler_.CreateCollection(context_map_[context], request->collection_name(), request->dimension(), + request_handler_.CreateCollection(GetContext(context), request->collection_name(), request->dimension(), request->index_file_size(), request->metric_type()); SET_RESPONSE(response, status, context); @@ -258,7 +339,7 @@ GrpcRequestHandler::HasCollection(::grpc::ServerContext* context, const ::milvus bool has_collection = false; - Status status = request_handler_.HasCollection(context_map_[context], request->collection_name(), has_collection); + Status status = request_handler_.HasCollection(GetContext(context), request->collection_name(), has_collection); response->set_bool_reply(has_collection); SET_RESPONSE(response->mutable_status(), status, context); @@ -270,7 +351,7 @@ GrpcRequestHandler::DropCollection(::grpc::ServerContext* context, const ::milvu ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.DropCollection(context_map_[context], request->collection_name()); + Status status = request_handler_.DropCollection(GetContext(context), request->collection_name()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -289,8 +370,8 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus:: } } - Status status = request_handler_.CreateIndex(context_map_[context], request->collection_name(), - request->index_type(), json_params); + Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->index_type(), + json_params); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -309,7 +390,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc: // step 2: insert vectors Status status = - request_handler_.Insert(context_map_[context], request->collection_name(), vectors, request->partition_tag()); + request_handler_.Insert(GetContext(context), request->collection_name(), vectors, request->partition_tag()); // step 3: return id array response->mutable_vector_id_array()->Resize(static_cast(vectors.id_array_.size()), 0); @@ -329,7 +410,7 @@ GrpcRequestHandler::GetVectorByID(::grpc::ServerContext* context, const ::milvus std::vector vector_ids = {request->id()}; engine::VectorsData vectors; Status status = - request_handler_.GetVectorByID(context_map_[context], request->collection_name(), vector_ids, vectors); + request_handler_.GetVectorByID(GetContext(context), request->collection_name(), vector_ids, vectors); if (!vectors.float_data_.empty()) { response->mutable_vector_data()->mutable_float_data()->Resize(vectors.float_data_.size(), 0); @@ -351,7 +432,7 @@ GrpcRequestHandler::GetVectorIDs(::grpc::ServerContext* context, const ::milvus: CHECK_NULLPTR_RETURN(request); std::vector vector_ids; - Status status = request_handler_.GetVectorIDs(context_map_[context], request->collection_name(), + Status status = request_handler_.GetVectorIDs(GetContext(context), request->collection_name(), request->segment_name(), vector_ids); if (!vector_ids.empty()) { @@ -393,7 +474,8 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc: std::vector file_ids; TopKQueryResult result; fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id")); - Status status = request_handler_.Search(context_map_[context], request->collection_name(), vectors, request->topk(), + + Status status = request_handler_.Search(GetContext(context), request->collection_name(), vectors, request->topk(), json_params, partitions, file_ids, result); // step 5: construct and return result @@ -428,7 +510,7 @@ GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::g // step 3: search vectors TopKQueryResult result; - Status status = request_handler_.SearchByID(context_map_[context], request->collection_name(), request->id(), + Status status = request_handler_.SearchByID(GetContext(context), request->collection_name(), request->id(), request->topk(), json_params, partitions, result); // step 4: construct and return result @@ -474,7 +556,7 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus // step 5: search vectors TopKQueryResult result; - Status status = request_handler_.Search(context_map_[context], search_request->collection_name(), vectors, + Status status = request_handler_.Search(GetContext(context), search_request->collection_name(), vectors, search_request->topk(), json_params, partitions, file_ids, result); // step 6: construct and return result @@ -492,7 +574,7 @@ GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::m CollectionSchema collection_schema; Status status = - request_handler_.DescribeCollection(context_map_[context], request->collection_name(), collection_schema); + request_handler_.DescribeCollection(GetContext(context), request->collection_name(), collection_schema); response->set_collection_name(collection_schema.collection_name_); response->set_dimension(collection_schema.dimension_); response->set_index_file_size(collection_schema.index_file_size_); @@ -508,7 +590,7 @@ GrpcRequestHandler::CountCollection(::grpc::ServerContext* context, const ::milv CHECK_NULLPTR_RETURN(request); int64_t row_count = 0; - Status status = request_handler_.CountCollection(context_map_[context], request->collection_name(), row_count); + Status status = request_handler_.CountCollection(GetContext(context), request->collection_name(), row_count); response->set_collection_row_count(row_count); SET_RESPONSE(response->mutable_status(), status, context); return ::grpc::Status::OK; @@ -520,7 +602,7 @@ GrpcRequestHandler::ShowCollections(::grpc::ServerContext* context, const ::milv CHECK_NULLPTR_RETURN(request); std::vector collections; - Status status = request_handler_.ShowCollections(context_map_[context], collections); + Status status = request_handler_.ShowCollections(GetContext(context), collections); for (auto& collection : collections) { response->add_collection_names(collection); } @@ -536,7 +618,7 @@ GrpcRequestHandler::ShowCollectionInfo(::grpc::ServerContext* context, const ::m CollectionInfo collection_info; Status status = - request_handler_.ShowCollectionInfo(context_map_[context], request->collection_name(), collection_info); + request_handler_.ShowCollectionInfo(GetContext(context), request->collection_name(), collection_info); ConstructCollectionInfo(collection_info, response); SET_RESPONSE(response->mutable_status(), status, context); @@ -549,7 +631,7 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Co CHECK_NULLPTR_RETURN(request); std::string reply; - Status status = request_handler_.Cmd(context_map_[context], request->cmd(), reply); + Status status = request_handler_.Cmd(GetContext(context), request->cmd(), reply); response->set_string_reply(reply); SET_RESPONSE(response->mutable_status(), status, context); @@ -568,7 +650,7 @@ GrpcRequestHandler::DeleteByID(::grpc::ServerContext* context, const ::milvus::g } // step 2: delete vector - Status status = request_handler_.DeleteByID(context_map_[context], request->collection_name(), vector_ids); + Status status = request_handler_.DeleteByID(GetContext(context), request->collection_name(), vector_ids); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -579,7 +661,7 @@ GrpcRequestHandler::PreloadCollection(::grpc::ServerContext* context, const ::mi ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.PreloadCollection(context_map_[context], request->collection_name()); + Status status = request_handler_.PreloadCollection(GetContext(context), request->collection_name()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -591,7 +673,7 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus CHECK_NULLPTR_RETURN(request); IndexParam param; - Status status = request_handler_.DescribeIndex(context_map_[context], request->collection_name(), param); + Status status = request_handler_.DescribeIndex(GetContext(context), request->collection_name(), param); response->set_collection_name(param.collection_name_); response->set_index_type(param.index_type_); ::milvus::grpc::KeyValuePair* kv = response->add_extra_params(); @@ -607,7 +689,7 @@ GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::gr ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.DropIndex(context_map_[context], request->collection_name()); + Status status = request_handler_.DropIndex(GetContext(context), request->collection_name()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -618,7 +700,7 @@ GrpcRequestHandler::CreatePartition(::grpc::ServerContext* context, const ::milv ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.CreatePartition(context_map_[context], request->collection_name(), request->tag()); + Status status = request_handler_.CreatePartition(GetContext(context), request->collection_name(), request->tag()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -630,7 +712,7 @@ GrpcRequestHandler::ShowPartitions(::grpc::ServerContext* context, const ::milvu CHECK_NULLPTR_RETURN(request); std::vector partitions; - Status status = request_handler_.ShowPartitions(context_map_[context], request->collection_name(), partitions); + Status status = request_handler_.ShowPartitions(GetContext(context), request->collection_name(), partitions); for (auto& partition : partitions) { response->add_partition_tag_array(partition.tag_); } @@ -645,7 +727,7 @@ GrpcRequestHandler::DropPartition(::grpc::ServerContext* context, const ::milvus ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.DropPartition(context_map_[context], request->collection_name(), request->tag()); + Status status = request_handler_.DropPartition(GetContext(context), request->collection_name(), request->tag()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -660,7 +742,7 @@ GrpcRequestHandler::Flush(::grpc::ServerContext* context, const ::milvus::grpc:: for (int32_t i = 0; i < request->collection_name_array().size(); i++) { collection_names.push_back(request->collection_name_array(i)); } - Status status = request_handler_.Flush(context_map_[context], collection_names); + Status status = request_handler_.Flush(GetContext(context), collection_names); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; @@ -671,7 +753,7 @@ GrpcRequestHandler::Compact(::grpc::ServerContext* context, const ::milvus::grpc ::milvus::grpc::Status* response) { CHECK_NULLPTR_RETURN(request); - Status status = request_handler_.Compact(context_map_[context], request->collection_name()); + Status status = request_handler_.Compact(GetContext(context), request->collection_name()); SET_RESPONSE(response, status, context); return ::grpc::Status::OK; diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.h b/core/src/server/grpc_impl/GrpcRequestHandler.h index 5f2781f350..385338272f 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.h +++ b/core/src/server/grpc_impl/GrpcRequestHandler.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include @@ -311,7 +312,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service, private: RequestHandler request_handler_; - std::unordered_map<::grpc::ServerContext*, std::shared_ptr> context_map_; + // std::unordered_map<::grpc::ServerContext*, std::shared_ptr> context_map_; + std::unordered_map> context_map_; std::shared_ptr tracer_; // std::unordered_map<::grpc::ServerContext*, std::unique_ptr> span_map_;