diff --git a/lib/src/HttpControllersRouter.cc b/lib/src/HttpControllersRouter.cc index c39bbfaf..07beb909 100644 --- a/lib/src/HttpControllersRouter.cc +++ b/lib/src/HttpControllersRouter.cc @@ -44,11 +44,8 @@ void HttpControllersRouter::doWhenNoHandlerFound( void HttpControllersRouter::init( const std::vector & /*ioLoops*/) { - for (auto &router : ctrlVector_) - { - router.regex_ = std::regex(router.pathParameterPattern_, - std::regex_constants::icase); - for (auto &binder : router.binders_) + auto initFilters = [](auto &binders) { + for (auto &binder : binders) { if (binder) { @@ -56,6 +53,21 @@ void HttpControllersRouter::init( filters_function::createFilters(binder->filterNames_); } } + }; + + for (auto &router : ctrlVector_) + { + router.regex_ = std::regex(router.pathParameterPattern_, + std::regex_constants::icase); + initFilters(router.binders_); + } + + for (auto &p : ctrlMap_) + { + auto &router = p.second; + router.regex_ = std::regex(router.pathParameterPattern_, + std::regex_constants::icase); + initFilters(router.binders_); } } @@ -63,8 +75,7 @@ std::vector> HttpControllersRouter::getHandlersInfo() const { std::vector> ret; - for (auto &item : ctrlVector_) - { + auto gatherInfo = [&](const auto &item) { for (size_t i = 0; i < Invalid; ++i) { if (item.binders_[i]) @@ -80,6 +91,14 @@ HttpControllersRouter::getHandlersInfo() const ret.emplace_back(std::move(info)); } } + }; + for (auto &item : ctrlVector_) + { + gatherInfo(item); + } + for (auto &data : ctrlMap_) + { + gatherInfo(data.second); } return ret; } @@ -104,7 +123,7 @@ void HttpControllersRouter::addHttpRegex( { if (router.pathParameterPattern_ == regExp) { - if (validMethods.size() > 0) + if (!validMethods.empty()) { for (auto const &method : validMethods) { @@ -126,7 +145,7 @@ void HttpControllersRouter::addHttpRegex( struct HttpControllerRouterItem router; router.pathParameterPattern_ = regExp; router.pathPattern_ = regExp; - if (validMethods.size() > 0) + if (!validMethods.empty()) { for (auto const &method : validMethods) { @@ -153,8 +172,8 @@ void HttpControllersRouter::addHttpPath( // Path is like /api/v1/service/method/{1}/{2}/xxx... std::vector places; std::string tmpPath = path; - std::string paras = ""; - std::regex regex = std::regex("\\{([^/]*)\\}"); + std::string paras; + static const std::regex regex("\\{([^/]*)\\}"); std::smatch results; auto pos = tmpPath.find('?'); if (pos != std::string::npos) @@ -174,7 +193,7 @@ void HttpControllersRouter::addHttpPath( return std::isdigit(c); })) { - size_t place = (size_t)std::stoi(result); + auto place = (size_t)std::stoi(result); if (place > binder->paramCount() || place == 0) { LOG_ERROR << "Parameter placeholder(value=" << place @@ -196,13 +215,13 @@ void HttpControllersRouter::addHttpPath( } else { - std::regex regNumberAndName("([0-9]+):.*"); + static const std::regex regNumberAndName("([0-9]+):.*"); std::smatch regexResult; if (std::regex_match(result, regexResult, regNumberAndName)) { assert(regexResult.size() == 2 && regexResult[1].matched); auto num = regexResult[1].str(); - size_t place = (size_t)std::stoi(num); + auto place = (size_t)std::stoi(num); if (place > binder->paramCount() || place == 0) { LOG_ERROR << "Parameter placeholder(value=" << place @@ -247,7 +266,7 @@ void HttpControllersRouter::addHttpPath( std::vector> parametersPlaces; if (!paras.empty()) { - std::regex pregex("([^&]*)=\\{([^&]*)\\}&*"); + static const std::regex pregex("([^&]*)=\\{([^&]*)\\}&*"); while (std::regex_search(paras, results, pregex)) { if (results.size() > 2) @@ -258,7 +277,7 @@ void HttpControllersRouter::addHttpPath( return std::isdigit(c); })) { - size_t place = (size_t)std::stoi(result); + auto place = (size_t)std::stoi(result); if (place > binder->paramCount() || place == 0) { LOG_ERROR << "Parameter placeholder(value=" << place @@ -296,7 +315,7 @@ void HttpControllersRouter::addHttpPath( assert(regexResult.size() == 2 && regexResult[1].matched); auto num = regexResult[1].str(); - size_t place = (size_t)std::stoi(num); + auto place = (size_t)std::stoi(num); if (place > binder->paramCount() || place == 0) { LOG_ERROR << "Parameter placeholder(value=" << place @@ -367,34 +386,53 @@ void HttpControllersRouter::addHttpPath( // Recreate this with the correct number of threads. binderInfo->responseCache_ = IOThreadStorage(); }); + bool routingRequiresRegex = (path != pathParameterPattern); + HttpControllerRouterItem *existingRouterItemPtr = nullptr; + + // If exists another controllers on the same route. Updathe them then exit + if (routingRequiresRegex) { for (auto &router : ctrlVector_) { if (router.pathParameterPattern_ == pathParameterPattern) - { - if (validMethods.size() > 0) - { - for (auto const &method : validMethods) - { - router.binders_[method] = binderInfo; - if (method == Options) - binderInfo->isCORS_ = true; - } - } - else - { - binderInfo->isCORS_ = true; - for (int i = 0; i < Invalid; ++i) - router.binders_[i] = binderInfo; - } - return; - } + existingRouterItemPtr = &router; } } + else + { + std::string loweredPath; + loweredPath.resize(path.size()); + std::transform(path.begin(), path.end(), loweredPath.begin(), tolower); + auto it = ctrlMap_.find(loweredPath); + if (it != ctrlMap_.end()) + existingRouterItemPtr = &it->second; + } + + if (existingRouterItemPtr != nullptr) + { + auto &router = *existingRouterItemPtr; + if (!validMethods.empty()) + { + for (auto const &method : validMethods) + { + router.binders_[method] = binderInfo; + if (method == Options) + binderInfo->isCORS_ = true; + } + } + else + { + binderInfo->isCORS_ = true; + for (int i = 0; i < Invalid; ++i) + router.binders_[i] = binderInfo; + } + return; + } + struct HttpControllerRouterItem router; router.pathParameterPattern_ = pathParameterPattern; router.pathPattern_ = path; - if (validMethods.size() > 0) + if (!validMethods.empty()) { for (auto const &method : validMethods) { @@ -409,7 +447,16 @@ void HttpControllersRouter::addHttpPath( for (int i = 0; i < Invalid; ++i) router.binders_[i] = binderInfo; } - ctrlVector_.push_back(std::move(router)); + + if (routingRequiresRegex) + ctrlVector_.push_back(std::move(router)); + else + { + std::string loweredPath; + loweredPath.resize(path.size()); + std::transform(path.begin(), path.end(), loweredPath.begin(), tolower); + ctrlMap_[loweredPath] = std::move(router); + } } void HttpControllersRouter::route( @@ -417,122 +464,145 @@ void HttpControllersRouter::route( std::function &&callback) { // Find http controller - for (auto &routerItem : ctrlVector_) + HttpControllerRouterItem *routerItemPtr = nullptr; + std::smatch result; + std::string loweredPath = req->path(); + std::transform(loweredPath.begin(), + loweredPath.end(), + loweredPath.begin(), + tolower); + + auto it = ctrlMap_.find(loweredPath); + // Try to find a controller in the hash map. If can't linear search + // with regex. + if (it != ctrlMap_.end()) { - std::smatch result; - auto const &ctrlRegex = routerItem.regex_; - if (std::regex_match(req->path(), result, ctrlRegex)) + routerItemPtr = &it->second; + } + else + { + for (auto &item : ctrlVector_) { - assert(Invalid > req->method()); - req->setMatchedPathPattern(routerItem.pathPattern_); - auto &binder = routerItem.binders_[req->method()]; - if (!binder) + auto const &ctrlRegex = item.regex_; + if (std::regex_match(req->path(), result, ctrlRegex)) { - // Invalid Http Method - if (req->method() != Options) - { - callback( - app().getCustomErrorHandler()(k405MethodNotAllowed)); - } - else - { - callback(app().getCustomErrorHandler()(k403Forbidden)); - } - return; + routerItemPtr = &item; + break; } - if (!postRoutingObservers_.empty()) - { - for (auto &observer : postRoutingObservers_) - { - observer(req); - } - } - if (postRoutingAdvices_.empty()) - { - if (!binder->filters_.empty()) - { - auto &filters = binder->filters_; - auto callbackPtr = std::make_shared< - std::function>( - std::move(callback)); - filters_function::doFilters( - filters, - req, - callbackPtr, - [req, - callbackPtr, - this, - &binder, - &routerItem, - result = std::move(result)]() mutable { - doPreHandlingAdvices(binder, - routerItem, - req, - std::move(result), - std::move(*callbackPtr)); - }); - } - else - { - doPreHandlingAdvices(binder, - routerItem, - req, - std::move(result), - std::move(callback)); - } - } - else - { - auto callbackPtr = std::make_shared< - std::function>( - std::move(callback)); - doAdvicesChain( - postRoutingAdvices_, - 0, - req, - callbackPtr, - [&binder, - callbackPtr, - req, - this, - &routerItem, - result = std::move(result)]() mutable { - if (!binder->filters_.empty()) - { - auto &filters = binder->filters_; - filters_function::doFilters( - filters, - req, - callbackPtr, - [this, - req, - callbackPtr, - &binder, - &routerItem, - result = std::move(result)]() mutable { - doPreHandlingAdvices(binder, - routerItem, - req, - std::move(result), - std::move( - *callbackPtr)); - }); - } - else - { - doPreHandlingAdvices(binder, - routerItem, - req, - std::move(result), - std::move(*callbackPtr)); - } - }); - } - return; } } + // No handler found - doWhenNoHandlerFound(req, std::move(callback)); + if (routerItemPtr == nullptr) + { + doWhenNoHandlerFound(req, std::move(callback)); + return; + } + + HttpControllerRouterItem &routerItem = *routerItemPtr; + assert(Invalid > req->method()); + req->setMatchedPathPattern(routerItem.pathPattern_); + auto &binder = routerItem.binders_[req->method()]; + if (!binder) + { + // Invalid Http Method + if (req->method() != Options) + { + callback(app().getCustomErrorHandler()(k405MethodNotAllowed)); + } + else + { + callback(app().getCustomErrorHandler()(k403Forbidden)); + } + return; + } + if (!postRoutingObservers_.empty()) + { + for (auto &observer : postRoutingObservers_) + { + observer(req); + } + } + if (postRoutingAdvices_.empty()) + { + if (!binder->filters_.empty()) + { + auto &filters = binder->filters_; + auto callbackPtr = + std::make_shared>( + std::move(callback)); + filters_function::doFilters(filters, + req, + callbackPtr, + [req, + callbackPtr, + this, + &binder, + &routerItem, + result = std::move(result)]() mutable { + doPreHandlingAdvices( + binder, + routerItem, + req, + std::move(result), + std::move(*callbackPtr)); + }); + } + else + { + doPreHandlingAdvices(binder, + routerItem, + req, + std::move(result), + std::move(callback)); + } + } + else + { + auto callbackPtr = + std::make_shared>( + std::move(callback)); + doAdvicesChain(postRoutingAdvices_, + 0, + req, + callbackPtr, + [&binder, + callbackPtr, + req, + this, + &routerItem, + result = std::move(result)]() mutable { + if (!binder->filters_.empty()) + { + auto &filters = binder->filters_; + filters_function::doFilters( + filters, + req, + callbackPtr, + [this, + req, + callbackPtr, + &binder, + &routerItem, + result = std::move(result)]() mutable { + doPreHandlingAdvices(binder, + routerItem, + req, + std::move(result), + std::move( + *callbackPtr)); + }); + } + else + { + doPreHandlingAdvices(binder, + routerItem, + req, + std::move(result), + std::move(*callbackPtr)); + } + }); + } } void HttpControllersRouter::doControllerHandler( @@ -620,7 +690,6 @@ void HttpControllersRouter::doControllerHandler( } invokeCallback(callback, req, resp); }); - return; } void HttpControllersRouter::doPreHandlingAdvices( diff --git a/lib/src/HttpControllersRouter.h b/lib/src/HttpControllersRouter.h index 3b19610e..5ad26b00 100644 --- a/lib/src/HttpControllersRouter.h +++ b/lib/src/HttpControllersRouter.h @@ -96,6 +96,7 @@ class HttpControllersRouter : public trantor::NonCopyable CtrlBinderPtr binders_[Invalid]{ nullptr}; // The enum value of Invalid is the http methods number }; + std::unordered_map ctrlMap_; std::vector ctrlVector_; const std::vector TEST_CTX) REQUIRE(result == ReqResult::Ok); std::shared_ptr ret = *resp; - REQUIRE(resp != nullptr); + REQUIRE(ret != nullptr); CHECK((*ret)["result"].asString() == "ok"); }); // Post json again @@ -142,7 +142,7 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) REQUIRE(result == ReqResult::Ok); std::shared_ptr ret = *resp; - REQUIRE(resp != nullptr); + REQUIRE(ret != nullptr); CHECK((*ret)["result"].asString() == "ok"); }); @@ -156,7 +156,7 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) REQUIRE(result == ReqResult::Ok); std::shared_ptr ret = *resp; - REQUIRE(resp != nullptr); + REQUIRE(ret != nullptr); CHECK((*ret)["result"].asString() == "ok"); }); @@ -527,7 +527,7 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) const HttpResponsePtr &resp) { REQUIRE(result == ReqResult::Ok); auto ret = resp->getJsonObject(); - CHECK(ret != nullptr); + REQUIRE(ret != nullptr); CHECK((*ret)["result"].asString() == "ok"); }); @@ -540,7 +540,7 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) const HttpResponsePtr &resp) { REQUIRE(result == ReqResult::Ok); auto ret = resp->getJsonObject(); - CHECK(ret != nullptr); + REQUIRE(ret != nullptr); CHECK((*ret)["result"].asString() == "ok"); }); @@ -579,7 +579,6 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) body->end(), resp->getBody().begin())); }); - // return; // Test file upload UploadFile file1("./drogon.jpg"); UploadFile file2("./drogon.jpg", "drogon1.jpg"); @@ -593,13 +592,12 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) const HttpResponsePtr &resp) { REQUIRE(result == ReqResult::Ok); auto json = resp->getJsonObject(); - CHECK(json != nullptr); + REQUIRE(json != nullptr); CHECK((*json)["result"].asString() == "ok"); CHECK((*json)["P1"] == "upload"); CHECK((*json)["P2"] == "test"); }); - // return; // Test file upload, file type and extension interface. UploadFile image("./drogon.jpg"); req = HttpRequest::newFileUploadRequest({image}); @@ -611,11 +609,12 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) const HttpResponsePtr &resp) { REQUIRE(result == ReqResult::Ok); auto json = resp->getJsonObject(); - CHECK(json != nullptr); + REQUIRE(json != nullptr); CHECK((*json)["P1"] == "upload"); CHECK((*json)["P2"] == "test"); }); + // Test exception handling req = HttpRequest::newHttpRequest(); req->setMethod(drogon::Get); req->setPath("/api/v1/this_will_fail"); @@ -625,6 +624,129 @@ void doTest(const HttpClientPtr &client, std::shared_ptr TEST_CTX) CHECK(resp->getStatusCode() == k500InternalServerError); }); + // The result of this API is cached for (almost) forever. And the endpoint + // increments a internal counter on each invoke. This tests if the respond + // is taken from the cache after the first invoke. + // Try poking the cache test endpoint 3 times. They should all respond 0 + // since the first respond is cached by the server. + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + CHECK(resp->body() == "0"); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + CHECK(resp->body() == "0"); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + CHECK(resp->body() == "0"); + }); + + // This API caches it's result on the third (counting from 1) calls. Thus + // we expect to always see 2 upon the third call. And all previous calls + // should be less than or equal to 2, as another test is also poking the API + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest2"); + client->sendRequest( + req, [req, TEST_CTX](ReqResult result, const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + int n; + CHECK_NOTHROW(n = std::stoi(std::string(resp->body()))); + CHECK(n <= 2); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest2"); + client->sendRequest( + req, [req, TEST_CTX](ReqResult result, const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + int n; + CHECK_NOTHROW(n = std::stoi(std::string(resp->body()))); + CHECK(n <= 2); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/api/v1/ApiTest/cacheTest2"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + CHECK(resp->body() == "2"); + }); + + // Same as cacheTest2. But the server has to handle this API through regex. + // it is intentionally made that the final part of the path can't conatin + // a "z" character + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/cacheTestRegex/foobar"); + client->sendRequest( + req, [req, TEST_CTX](ReqResult result, const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + int n; + CHECK_NOTHROW(n = std::stoi(std::string(resp->body()))); + CHECK(n <= 2); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/cacheTestRegex/deadbeef"); + client->sendRequest( + req, [req, TEST_CTX](ReqResult result, const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + int n; + CHECK_NOTHROW(n = std::stoi(std::string(resp->body()))); + CHECK(n <= 2); + }); + + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/cacheTestRegex/leet"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k200OK); + CHECK(resp->body() == "2"); + }); + req = HttpRequest::newHttpRequest(); + req->setMethod(drogon::Get); + req->setPath("/cacheTestRegex/zebra"); + client->sendRequest(req, + [req, TEST_CTX](ReqResult result, + const HttpResponsePtr &resp) { + REQUIRE(result == ReqResult::Ok); + CHECK(resp->getStatusCode() == k404NotFound); + }); + #if defined(__cpp_impl_coroutine) sync_wait([client, TEST_CTX]() -> Task<> { // Test coroutine requests diff --git a/lib/tests/integration_test/server/api_v1_ApiTest.cc b/lib/tests/integration_test/server/api_v1_ApiTest.cc index 2449d571..12ded4f0 100644 --- a/lib/tests/integration_test/server/api_v1_ApiTest.cc +++ b/lib/tests/integration_test/server/api_v1_ApiTest.cc @@ -453,4 +453,58 @@ void ApiTest::regexTest(const HttpRequestPtr &req, ret["p2"] = std::move(p2); auto resp = HttpResponse::newHttpJsonResponse(std::move(ret)); callback(resp); -} \ No newline at end of file +} + +static std::mutex cacheTestMtx; +void ApiTest::cacheTest(const HttpRequestPtr &req, + std::function &&callback) +{ + std::unique_lock lk(cacheTestMtx); + static size_t callCount = 0; + + auto resp = HttpResponse::newHttpResponse(); + resp->setBody(std::to_string(callCount)); + resp->setContentTypeCode(CT_TEXT_PLAIN); + // Expire after a millennia + resp->setExpiredTime(31536000000); + callback(resp); + callCount++; +} + +static std::mutex cacheTest2Mtx; +void ApiTest::cacheTest2( + const HttpRequestPtr &req, + std::function &&callback) +{ + std::unique_lock lk(cacheTest2Mtx); + static size_t callCount = 0; + + auto resp = HttpResponse::newHttpResponse(); + LOG_ERROR << callCount; + resp->setBody(std::to_string(callCount)); + resp->setContentTypeCode(CT_TEXT_PLAIN); + // Expire after a millennia + if (callCount >= 2) + resp->setExpiredTime(31536000000); + callback(resp); + callCount++; +} + +static std::mutex regexCacheApiMtx; +void ApiTest::cacheTestRegex( + const HttpRequestPtr &req, + std::function &&callback) +{ + std::unique_lock lk(regexCacheApiMtx); + static size_t callCount = 0; + + auto resp = HttpResponse::newHttpResponse(); + LOG_ERROR << callCount; + resp->setBody(std::to_string(callCount)); + resp->setContentTypeCode(CT_TEXT_PLAIN); + // Expire after a millennia + if (callCount >= 2) + resp->setExpiredTime(31536000000); + callback(resp); + callCount++; +} diff --git a/lib/tests/integration_test/server/api_v1_ApiTest.h b/lib/tests/integration_test/server/api_v1_ApiTest.h index 867d386f..7da205a6 100644 --- a/lib/tests/integration_test/server/api_v1_ApiTest.h +++ b/lib/tests/integration_test/server/api_v1_ApiTest.h @@ -37,6 +37,11 @@ class ApiTest : public drogon::HttpController METHOD_ADD(ApiTest::formTest, "/form", Post); METHOD_ADD(ApiTest::attributesTest, "/attrs", Get); ADD_METHOD_VIA_REGEX(ApiTest::regexTest, "/reg/([0-9]*)/(.*)", Get); + METHOD_ADD(ApiTest::cacheTest, "/cacheTest", Get); + METHOD_ADD(ApiTest::cacheTest2, "/cacheTest2", Get); + ADD_METHOD_VIA_REGEX(ApiTest::cacheTestRegex, + "/cacheTestRegex/[a-y]+", + Get); METHOD_LIST_END void get(const HttpRequestPtr &req, @@ -73,6 +78,13 @@ class ApiTest : public drogon::HttpController { app().quit(); } + void cacheTest(const HttpRequestPtr &req, + std::function &&callback); + void cacheTest2(const HttpRequestPtr &req, + std::function &&callback); + void cacheTestRegex( + const HttpRequestPtr &req, + std::function &&callback); public: ApiTest()