Support per-method middlewares. (#2015)

This commit is contained in:
Nitromelon 2024-05-10 10:03:20 +08:00 committed by an-tao
parent abbcf6023d
commit 5b7cefd32c
27 changed files with 760 additions and 304 deletions

View File

@ -253,7 +253,7 @@ set(DROGON_SOURCES
lib/src/Cookie.cc
lib/src/DrClassMap.cc
lib/src/DrTemplateBase.cc
lib/src/FiltersFunction.cc
lib/src/MiddlewaresFunction.cc
lib/src/FixedWindowRateLimiter.cc
lib/src/GlobalFilters.cc
lib/src/Histogram.cc
@ -302,7 +302,7 @@ set(private_headers
lib/src/CacheFile.h
lib/src/ConfigLoader.h
lib/src/ControllerBinderBase.h
lib/src/FiltersFunction.h
lib/src/MiddlewaresFunction.h
lib/src/HttpAppFrameworkImpl.h
lib/src/HttpClientImpl.h
lib/src/HttpConnectionLimit.h
@ -557,6 +557,7 @@ set(DROGON_HEADERS
lib/inc/drogon/HttpClient.h
lib/inc/drogon/HttpController.h
lib/inc/drogon/HttpFilter.h
lib/inc/drogon/HttpMiddleware.h
lib/inc/drogon/HttpRequest.h
lib/inc/drogon/HttpResponse.h
lib/inc/drogon/HttpSimpleController.h

View File

@ -322,7 +322,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
[Check Method]---------------->[405]----------->+
| |
v |
[Filters]------->[Filter callback]----------->+
[Filters/Middlewares]------>[Filter callback]------>+
| |
v Y |
[Is OPTIONS method?]------------->[200]----------->+
@ -335,6 +335,9 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
| |
v |
Post-handling join point o---------------------------------------->+
| |
v |
[Middlewares post logic]--->[Middleware callback]--->+
@endcode
*
@ -368,7 +371,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
/// Register an advice called after routing
/**
* @param advice is called immediately after the request matches a handler
* path and before any 'doFilter' method of filters applies. The parameters
* path and before any filters/middlewares applies. The parameters
* of the advice are same as those of the doFilter method of the Filter
* class.
*/
@ -390,8 +393,8 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
/// Register an advice called before the request is handled
/**
* @param advice is called immediately after the request is approved by all
* filters and before it is handled. The parameters of the advice are
* same as those of the doFilter method of the Filter class.
* filters/middlewares and before it is handled. The parameters of the
* advice are same as those of the doFilter method of the Filter class.
*/
virtual HttpAppFramework &registerPreHandlingAdvice(
const std::function<void(const HttpRequestPtr &,
@ -472,8 +475,8 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
* called.
* @param ctrlName is the name of the controller. It includes the namespace
* to which the controller belongs.
* @param filtersAndMethods is a vector containing Http methods or filter
* name constraints.
* @param constraints is a vector containing Http methods or middleware
names
*
* Example:
* @code
@ -487,8 +490,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
virtual HttpAppFramework &registerHttpSimpleController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
std::vector<internal::HttpConstraint>{}) = 0;
const std::vector<internal::HttpConstraint> &constraints = {}) = 0;
/// Register a handler into the framework.
/**
@ -496,7 +498,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
* pathPattern, the handler indicated by the function parameter is called.
* @param function indicates any type of callable object with a valid
* processing interface.
* @param filtersAndMethods is the same as the third parameter in the above
* @param constraints is the same as the third parameter in the above
* method.
*
* Example:
@ -522,8 +524,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
HttpAppFramework &registerHandler(
const std::string &pathPattern,
FUNCTION &&function,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
std::vector<internal::HttpConstraint>{},
const std::vector<internal::HttpConstraint> &constraints = {},
const std::string &handlerName = "")
{
LOG_TRACE << "pathPattern:" << pathPattern;
@ -533,17 +534,16 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
getLoop()->queueInLoop([binder]() { binder->createHandlerInstance(); });
std::vector<HttpMethod> validMethods;
std::vector<std::string> filters;
for (auto const &filterOrMethod : filtersAndMethods)
std::vector<std::string> middlewares;
for (auto const &constraint : constraints)
{
if (filterOrMethod.type() == internal::ConstraintType::HttpFilter)
if (constraint.type() == internal::ConstraintType::HttpMiddleware)
{
filters.push_back(filterOrMethod.getFilterName());
middlewares.push_back(constraint.getMiddlewareName());
}
else if (filterOrMethod.type() ==
internal::ConstraintType::HttpMethod)
else if (constraint.type() == internal::ConstraintType::HttpMethod)
{
validMethods.push_back(filterOrMethod.getHttpMethod());
validMethods.push_back(constraint.getHttpMethod());
}
else
{
@ -552,7 +552,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
}
}
registerHttpController(
pathPattern, binder, validMethods, filters, handlerName);
pathPattern, binder, validMethods, middlewares, handlerName);
return *this;
}
@ -566,8 +566,8 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
* subexpression is sequentially mapped to a handler parameter.
* @param function indicates any type of callable object with a valid
* processing interface.
* @param filtersAndMethods is the same as the third parameter in the above
* method.
* @param constraints is the same as the third parameter in the
* above method.
* @param handlerName a name for the handler.
* @return HttpAppFramework&
*/
@ -575,8 +575,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
HttpAppFramework &registerHandlerViaRegex(
const std::string &regExp,
FUNCTION &&function,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
std::vector<internal::HttpConstraint>{},
const std::vector<internal::HttpConstraint> &constraints = {},
const std::string &handlerName = "")
{
LOG_TRACE << "regex:" << regExp;
@ -586,17 +585,16 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
std::forward<FUNCTION>(function));
std::vector<HttpMethod> validMethods;
std::vector<std::string> filters;
for (auto const &filterOrMethod : filtersAndMethods)
std::vector<std::string> middlewares;
for (auto const &constraint : constraints)
{
if (filterOrMethod.type() == internal::ConstraintType::HttpFilter)
if (constraint.type() == internal::ConstraintType::HttpMiddleware)
{
filters.push_back(filterOrMethod.getFilterName());
middlewares.push_back(constraint.getMiddlewareName());
}
else if (filterOrMethod.type() ==
internal::ConstraintType::HttpMethod)
else if (constraint.type() == internal::ConstraintType::HttpMethod)
{
validMethods.push_back(filterOrMethod.getHttpMethod());
validMethods.push_back(constraint.getHttpMethod());
}
else
{
@ -605,7 +603,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
}
}
registerHttpControllerViaRegex(
regExp, binder, validMethods, filters, handlerName);
regExp, binder, validMethods, middlewares, handlerName);
return *this;
}
@ -617,8 +615,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
virtual HttpAppFramework &registerWebSocketController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
std::vector<internal::HttpConstraint>{}) = 0;
const std::vector<internal::HttpConstraint> &constraints = {}) = 0;
/// Register controller objects created and initialized by the user
/**
@ -919,7 +916,8 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
* extension can be accessed.
* @param isRecursive If it is set to false, files in sub directories can't
* be accessed.
* @param filters The list of filters which acting on the location.
* @param middlewareNames The list of middlewares which acting on the
* location.
* @return HttpAppFramework&
*/
virtual HttpAppFramework &addALocation(
@ -929,7 +927,7 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
bool isCaseSensitive = false,
bool allowAll = true,
bool isRecursive = true,
const std::vector<std::string> &filters = {}) = 0;
const std::vector<std::string> &middlewareNames = {}) = 0;
/// Set the path to store uploaded files.
/**
@ -1563,14 +1561,14 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
virtual void registerHttpController(
const std::string &pathPattern,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods = std::vector<HttpMethod>(),
const std::vector<std::string> &filters = std::vector<std::string>(),
const std::vector<HttpMethod> &validMethods = {},
const std::vector<std::string> &middlewareNames = {},
const std::string &handlerName = "") = 0;
virtual void registerHttpControllerViaRegex(
const std::string &regExp,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName) = 0;
};

View File

@ -66,8 +66,7 @@ class HttpController : public DrObject<T>, public HttpControllerBase
static void registerMethod(
FUNCTION &&function,
const std::string &pattern,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
std::vector<internal::HttpConstraint>{},
const std::vector<internal::HttpConstraint> &constraints = {},
bool classNameInPath = true,
const std::string &handlerName = "")
{
@ -88,12 +87,12 @@ class HttpController : public DrObject<T>, public HttpControllerBase
if (pattern.empty() || pattern[0] == '/')
app().registerHandler(path + pattern,
std::forward<FUNCTION>(function),
filtersAndMethods,
constraints,
handlerName);
else
app().registerHandler(path + "/" + pattern,
std::forward<FUNCTION>(function),
filtersAndMethods,
constraints,
handlerName);
}
else
@ -105,7 +104,7 @@ class HttpController : public DrObject<T>, public HttpControllerBase
}
app().registerHandler(path,
std::forward<FUNCTION>(function),
filtersAndMethods,
constraints,
handlerName);
}
}
@ -114,13 +113,13 @@ class HttpController : public DrObject<T>, public HttpControllerBase
static void registerMethodViaRegex(
FUNCTION &&function,
const std::string &regExp,
const std::vector<internal::HttpConstraint> &filtersAndMethods =
const std::vector<internal::HttpConstraint> &constraints =
std::vector<internal::HttpConstraint>{},
const std::string &handlerName = "")
{
app().registerHandlerViaRegex(regExp,
std::forward<FUNCTION>(function),
filtersAndMethods,
constraints,
handlerName);
}

View File

@ -18,6 +18,7 @@
#include <drogon/drogon_callbacks.h>
#include <drogon/HttpRequest.h>
#include <drogon/HttpResponse.h>
#include <drogon/HttpMiddleware.h>
#include <memory>
#ifdef __cpp_impl_coroutine
@ -30,7 +31,8 @@ namespace drogon
* @brief The abstract base class for filters
* For more details on the class, see the wiki site (the 'Filter' section)
*/
class DROGON_EXPORT HttpFilterBase : public virtual DrObjectBase
class DROGON_EXPORT HttpFilterBase : public virtual DrObjectBase,
public HttpMiddlewareBase
{
public:
/// This virtual function should be overridden in subclasses.
@ -48,6 +50,24 @@ class DROGON_EXPORT HttpFilterBase : public virtual DrObjectBase
FilterCallback &&fcb,
FilterChainCallback &&fccb) = 0;
~HttpFilterBase() override = default;
private:
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) final
{
auto mcbPtr = std::make_shared<MiddlewareCallback>(std::move(mcb));
doFilter(
req,
[mcbPtr](const HttpResponsePtr &resp) {
(*mcbPtr)(resp);
}, // fcb, intercept the response
[nextCb = std::move(nextCb), mcbPtr]() mutable {
nextCb([mcbPtr = std::move(mcbPtr)](
const HttpResponsePtr &resp) { (*mcbPtr)(resp); });
} // fccb, call the next middleware
);
}
};
/**
@ -55,7 +75,7 @@ class DROGON_EXPORT HttpFilterBase : public virtual DrObjectBase
*
* @tparam T The type of the implementation class
* @tparam AutoCreation The flag for automatically creating, user can set this
* flag to false for classes that have nondefault constructors.
* flag to false for classes that have non-default constructors.
*/
template <typename T, bool AutoCreation = true>
class HttpFilter : public DrObject<T>, public HttpFilterBase
@ -65,14 +85,6 @@ class HttpFilter : public DrObject<T>, public HttpFilterBase
~HttpFilter() override = default;
};
namespace internal
{
DROGON_EXPORT void handleException(
const std::exception &,
const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&);
}
#ifdef __cpp_impl_coroutine
template <typename T, bool AutoCreation = true>
class HttpCoroFilter : public DrObject<T>, public HttpFilterBase

View File

@ -0,0 +1,151 @@
/**
*
* @file HttpMiddleware.h
* @author Nitromelon
*
* Copyright 2024, Nitromelon. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/DrObject.h>
#include <drogon/drogon_callbacks.h>
#include <drogon/HttpRequest.h>
#include <drogon/HttpResponse.h>
#include <memory>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
namespace drogon
{
/**
* @brief The abstract base class for middleware
*/
class DROGON_EXPORT HttpMiddlewareBase : public virtual DrObjectBase
{
public:
/**
* This virtual function should be overridden in subclasses.
*
* Example:
* @code
* void invoke(const HttpRequestPtr &req,
* MiddlewareNextCallback &&nextCb,
* MiddlewareCallback &&mcb) override
* {
* if (req->path() == "/some/path") {
* // intercept directly
* mcb(HttpResponse::newNotFoundResponse(req));
* return;
* }
* // Do something before calling the next middleware
* nextCb([mcb = std::move(mcb)](const HttpResponsePtr &resp) {
* // Do something after the next middleware returns
* mcb(resp);
* });
* }
* @endcode
*
*/
virtual void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) = 0;
~HttpMiddlewareBase() override = default;
};
/**
* @brief The reflection base class template for middlewares
*
* @tparam T The type of the implementation class
* @tparam AutoCreation The flag for automatically creating, user can set this
* flag to false for classes that have non-default constructors.
*/
template <typename T, bool AutoCreation = true>
class HttpMiddleware : public DrObject<T>, public HttpMiddlewareBase
{
public:
static constexpr bool isAutoCreation{AutoCreation};
~HttpMiddleware() override = default;
};
namespace internal
{
DROGON_EXPORT void handleException(
const std::exception &,
const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&);
}
#ifdef __cpp_impl_coroutine
struct [[nodiscard]] MiddlewareNextAwaiter
: public CallbackAwaiter<HttpResponsePtr>
{
public:
MiddlewareNextAwaiter(MiddlewareNextCallback &&nextCb)
: nextCb_(std::move(nextCb))
{
}
void await_suspend(std::coroutine_handle<> handle) noexcept
{
nextCb_([this, handle](const HttpResponsePtr &resp) {
setValue(resp);
handle.resume();
});
}
private:
MiddlewareNextCallback nextCb_;
};
template <typename T, bool AutoCreation = true>
class HttpCoroMiddleware : public DrObject<T>, public HttpMiddlewareBase
{
public:
static constexpr bool isAutoCreation{AutoCreation};
~HttpCoroMiddleware() override = default;
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) final
{
drogon::async_run([this,
req,
nextCb = std::move(nextCb),
mcb = std::move(mcb)]() mutable -> drogon::Task<> {
HttpResponsePtr resp;
try
{
resp = co_await invoke(req, {std::move(nextCb)});
}
catch (const std::exception &ex)
{
internal::handleException(ex, req, std::move(mcb));
co_return;
}
catch (...)
{
LOG_ERROR << "Exception not derived from std::exception";
co_return;
}
mcb(resp);
});
}
virtual Task<HttpResponsePtr> invoke(const HttpRequestPtr &req,
MiddlewareNextAwaiter &&next) = 0;
};
#endif
} // namespace drogon

View File

@ -76,7 +76,7 @@ class HttpSimpleController : public DrObject<T>, public HttpSimpleControllerBase
static void registerSelf__(
const std::string &path,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
LOG_TRACE << "register simple controller("
<< HttpSimpleController<T, AutoCreation>::classTypeName()
@ -84,7 +84,7 @@ class HttpSimpleController : public DrObject<T>, public HttpSimpleControllerBase
app().registerHttpSimpleController(
path,
HttpSimpleController<T, AutoCreation>::classTypeName(),
filtersAndMethods);
constraints);
}
private:

View File

@ -83,7 +83,7 @@ class WebSocketController : public DrObject<T>, public WebSocketControllerBase
static void registerSelf__(
const std::string &path,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
LOG_TRACE << "register websocket controller("
<< WebSocketController<T, AutoCreation>::classTypeName()
@ -91,7 +91,7 @@ class WebSocketController : public DrObject<T>, public WebSocketControllerBase
app().registerWebSocketController(
path,
WebSocketController<T, AutoCreation>::classTypeName(),
filtersAndMethods);
constraints);
}
private:

View File

@ -31,4 +31,9 @@ using AdviceDestroySessionCallback = std::function<void(const std::string &)>;
using FilterCallback = std::function<void(const HttpResponsePtr &)>;
using FilterChainCallback = std::function<void()>;
using HttpReqCallback = std::function<void(ReqResult, const HttpResponsePtr &)>;
using MiddlewareCallback = std::function<void(const HttpResponsePtr &)>;
using MiddlewareNextCallback =
std::function<void(std::function<void(const HttpResponsePtr &)> &&)>;
} // namespace drogon

View File

@ -25,7 +25,7 @@ enum class ConstraintType
{
None,
HttpMethod,
HttpFilter
HttpMiddleware
};
class HttpConstraint
@ -36,13 +36,14 @@ class HttpConstraint
{
}
HttpConstraint(const std::string &filterName)
: type_(ConstraintType::HttpFilter), filterName_(filterName)
HttpConstraint(std::string middlewareName)
: type_(ConstraintType::HttpMiddleware),
middlewareName_(std::move(middlewareName))
{
}
HttpConstraint(const char *filterName)
: type_(ConstraintType::HttpFilter), filterName_(filterName)
HttpConstraint(const char *middlewareName)
: type_(ConstraintType::HttpMiddleware), middlewareName_(middlewareName)
{
}
@ -56,15 +57,15 @@ class HttpConstraint
return method_;
}
const std::string &getFilterName() const
const std::string &getMiddlewareName() const
{
return filterName_;
return middlewareName_;
}
private:
ConstraintType type_{ConstraintType::None};
HttpMethod method_{HttpMethod::Invalid};
std::string filterName_;
std::string middlewareName_;
};
} // namespace internal
} // namespace drogon

View File

@ -23,7 +23,7 @@
namespace drogon
{
class HttpFilterBase;
class HttpMiddlewareBase;
/**
* @brief A component to associate router class and controller class
@ -31,8 +31,8 @@ class HttpFilterBase;
struct ControllerBinderBase
{
std::string handlerName_;
std::vector<std::string> filterNames_;
std::vector<std::shared_ptr<HttpFilterBase>> filters_;
std::vector<std::string> middlewareNames_;
std::vector<std::shared_ptr<HttpMiddlewareBase>> middlewares_;
IOThreadStorage<HttpResponsePtr> responseCache_;
std::shared_ptr<std::string> corsMethods_;
bool isCORS_{false};

View File

@ -1,100 +0,0 @@
/**
*
* @file FiltersFunction.cc
* @author An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "FiltersFunction.h"
#include "HttpRequestImpl.h"
#include "HttpResponseImpl.h"
#include "HttpAppFrameworkImpl.h"
#include <drogon/HttpFilter.h>
#include <queue>
namespace drogon
{
namespace filters_function
{
static void doFilterChains(
const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
size_t index,
const HttpRequestImplPtr &req,
std::shared_ptr<const std::function<void(const HttpResponsePtr &)>>
&&callbackPtr)
{
if (index < filters.size())
{
auto &filter = filters[index];
filter->doFilter(
req,
[/*copy*/ callbackPtr](const HttpResponsePtr &resp) {
(*callbackPtr)(resp);
},
[index, req, callbackPtr, &filters]() mutable {
auto ioLoop = req->getLoop();
if (ioLoop && !ioLoop->isInLoopThread())
{
ioLoop->queueInLoop(
[&filters,
index,
req,
callbackPtr = std::move(callbackPtr)]() mutable {
doFilterChains(filters,
index + 1,
req,
std::move(callbackPtr));
});
}
else
{
doFilterChains(filters,
index + 1,
req,
std::move(callbackPtr));
}
});
}
else
{
(*callbackPtr)(nullptr);
}
}
std::vector<std::shared_ptr<HttpFilterBase>> createFilters(
const std::vector<std::string> &filterNames)
{
std::vector<std::shared_ptr<HttpFilterBase>> filters;
for (auto const &filter : filterNames)
{
auto object_ = DrClassMap::getSingleInstance(filter);
auto filter_ = std::dynamic_pointer_cast<HttpFilterBase>(object_);
if (filter_)
filters.push_back(filter_);
else
{
LOG_ERROR << "filter " << filter << " not found";
}
}
return filters;
}
void doFilters(const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback)
{
auto callbackPtr =
std::make_shared<std::decay_t<decltype(callback)>>(std::move(callback));
doFilterChains(filters, 0, req, std::move(callbackPtr));
}
} // namespace filters_function
} // namespace drogon

View File

@ -1,33 +0,0 @@
/**
*
* FiltersFunction.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include "impl_forwards.h"
#include <memory>
#include <string>
#include <vector>
namespace drogon
{
namespace filters_function
{
std::vector<std::shared_ptr<HttpFilterBase>> createFilters(
const std::vector<std::string> &filterNames);
void doFilters(const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
} // namespace filters_function
} // namespace drogon

View File

@ -1,7 +1,7 @@
#include <drogon/plugins/GlobalFilters.h>
#include <drogon/DrClassMap.h>
#include <drogon/HttpAppFramework.h>
#include "FiltersFunction.h"
#include "MiddlewaresFunction.h"
#include "HttpRequestImpl.h"
#include "HttpAppFrameworkImpl.h"
@ -85,7 +85,7 @@ void GlobalFilters::initAndStart(const Json::Value &config)
}
}
drogon::filters_function::doFilters(
drogon::middlewares_function::doFilters(
thisPtr->filters_,
std::static_pointer_cast<HttpRequestImpl>(req),
[acb = std::move(acb),

View File

@ -285,22 +285,24 @@ HttpAppFramework &HttpAppFrameworkImpl::setFileTypes(
HttpAppFramework &HttpAppFrameworkImpl::registerWebSocketController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
assert(!routersInit_);
HttpControllersRouter::instance().registerWebSocketController(
pathName, ctrlName, filtersAndMethods);
HttpControllersRouter::instance().registerWebSocketController(pathName,
ctrlName,
constraints);
return *this;
}
HttpAppFramework &HttpAppFrameworkImpl::registerHttpSimpleController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
assert(!routersInit_);
HttpControllersRouter::instance().registerHttpSimpleController(
pathName, ctrlName, filtersAndMethods);
HttpControllersRouter::instance().registerHttpSimpleController(pathName,
ctrlName,
constraints);
return *this;
}
@ -308,28 +310,28 @@ void HttpAppFrameworkImpl::registerHttpController(
const std::string &pathPattern,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName)
{
assert(!pathPattern.empty());
assert(binder);
assert(!routersInit_);
HttpControllersRouter::instance().addHttpPath(
pathPattern, binder, validMethods, filters, handlerName);
pathPattern, binder, validMethods, middlewareNames, handlerName);
}
void HttpAppFrameworkImpl::registerHttpControllerViaRegex(
const std::string &regExp,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName)
{
assert(!regExp.empty());
assert(binder);
assert(!routersInit_);
HttpControllersRouter::instance().addHttpRegex(
regExp, binder, validMethods, filters, handlerName);
regExp, binder, validMethods, middlewareNames, handlerName);
}
HttpAppFramework &HttpAppFrameworkImpl::setThreadNum(size_t threadNum)
@ -1013,7 +1015,7 @@ HttpAppFramework &HttpAppFrameworkImpl::addALocation(
bool isCaseSensitive,
bool allowAll,
bool isRecursive,
const std::vector<std::string> &filters)
const std::vector<std::string> &middlewareNames)
{
StaticFileRouter::instance().addALocation(uriPrefix,
defaultContentType,
@ -1021,7 +1023,7 @@ HttpAppFramework &HttpAppFrameworkImpl::addALocation(
isCaseSensitive,
allowAll,
isRecursive,
filters);
middlewareNames);
return *this;
}

View File

@ -88,13 +88,11 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
HttpAppFramework &registerWebSocketController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
override;
const std::vector<internal::HttpConstraint> &constraints) override;
HttpAppFramework &registerHttpSimpleController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
override;
const std::vector<internal::HttpConstraint> &constraints) override;
HttpAppFramework &setCustom404Page(const HttpResponsePtr &resp,
bool set404) override
@ -246,7 +244,7 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
bool isCaseSensitive,
bool allowAll,
bool isRecursive,
const std::vector<std::string> &filters) override;
const std::vector<std::string> &middlewareNames) override;
const std::string &getUploadPath() const override
{
@ -642,13 +640,13 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
void registerHttpController(const std::string &pathPattern,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName) override;
void registerHttpControllerViaRegex(
const std::string &regExp,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName) override;
// We use an uuid string as session id;

View File

@ -15,9 +15,8 @@
#include "HttpControllersRouter.h"
#include "HttpControllerBinder.h"
#include "HttpRequestImpl.h"
#include "HttpResponseImpl.h"
#include "HttpAppFrameworkImpl.h"
#include "FiltersFunction.h"
#include "MiddlewaresFunction.h"
#include <drogon/HttpSimpleController.h>
#include <drogon/WebSocketController.h>
#include <algorithm>
@ -27,15 +26,15 @@ using namespace drogon;
void HttpControllersRouter::init(
const std::vector<trantor::EventLoop *> & /*ioLoops*/)
{
auto initFiltersAndCorsMethods = [](const auto &item) {
auto initMiddlewaresAndCorsMethods = [](const auto &item) {
auto corsMethods = std::make_shared<std::string>("OPTIONS,");
for (size_t i = 0; i < Invalid; ++i)
{
auto &binder = item.binders_[i];
if (binder)
{
binder->filters_ =
filters_function::createFilters(binder->filterNames_);
binder->middlewares_ = middlewares_function::createMiddlewares(
binder->middlewareNames_);
binder->corsMethods_ = corsMethods;
if (binder->isCORS_)
{
@ -56,19 +55,19 @@ void HttpControllersRouter::init(
for (auto &iter : simpleCtrlMap_)
{
initFiltersAndCorsMethods(iter.second);
initMiddlewaresAndCorsMethods(iter.second);
}
for (auto &iter : wsCtrlMap_)
{
initFiltersAndCorsMethods(iter.second);
initMiddlewaresAndCorsMethods(iter.second);
}
for (auto &router : ctrlVector_)
{
router.regex_ = std::regex(router.pathParameterPattern_,
std::regex_constants::icase);
initFiltersAndCorsMethods(router);
initMiddlewaresAndCorsMethods(router);
}
for (auto &p : ctrlMap_)
@ -76,7 +75,7 @@ void HttpControllersRouter::init(
auto &router = p.second;
router.regex_ = std::regex(router.pathParameterPattern_,
std::regex_constants::icase);
initFiltersAndCorsMethods(router);
initMiddlewaresAndCorsMethods(router);
}
}
@ -175,12 +174,12 @@ struct SimpleControllerProcessResult
{
std::string lowerPath;
std::vector<HttpMethod> validMethods;
std::vector<std::string> filters;
std::vector<std::string> middlewares;
};
static SimpleControllerProcessResult processSimpleControllerParams(
const std::string &pathName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
std::string path(pathName);
std::transform(pathName.begin(),
@ -188,16 +187,16 @@ static SimpleControllerProcessResult processSimpleControllerParams(
path.begin(),
[](unsigned char c) { return tolower(c); });
std::vector<HttpMethod> validMethods;
std::vector<std::string> filters;
for (const auto &filterOrMethod : filtersAndMethods)
std::vector<std::string> middlewareNames;
for (const auto &constraint : constraints)
{
if (filterOrMethod.type() == internal::ConstraintType::HttpFilter)
if (constraint.type() == internal::ConstraintType::HttpMiddleware)
{
filters.push_back(filterOrMethod.getFilterName());
middlewareNames.push_back(constraint.getMiddlewareName());
}
else if (filterOrMethod.type() == internal::ConstraintType::HttpMethod)
else if (constraint.type() == internal::ConstraintType::HttpMethod)
{
validMethods.push_back(filterOrMethod.getHttpMethod());
validMethods.push_back(constraint.getHttpMethod());
}
else
{
@ -208,26 +207,26 @@ static SimpleControllerProcessResult processSimpleControllerParams(
return {
std::move(path),
std::move(validMethods),
std::move(filters),
std::move(middlewareNames),
};
}
void HttpControllersRouter::registerHttpSimpleController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
assert(!pathName.empty());
assert(!ctrlName.empty());
// Note: some compiler version failed to handle structural bindings with
// lambda capture
auto result = processSimpleControllerParams(pathName, filtersAndMethods);
auto result = processSimpleControllerParams(pathName, constraints);
std::string path = std::move(result.lowerPath);
auto &item = simpleCtrlMap_[path];
auto binder = std::make_shared<HttpSimpleControllerBinder>();
binder->handlerName_ = ctrlName;
binder->filterNames_ = result.filters;
binder->middlewareNames_ = result.middlewares;
drogon::app().getLoop()->queueInLoop([this, binder, ctrlName, path]() {
auto &object_ = DrClassMap::getSingleInstance(ctrlName);
auto controller =
@ -249,17 +248,17 @@ void HttpControllersRouter::registerHttpSimpleController(
void HttpControllersRouter::registerWebSocketController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods)
const std::vector<internal::HttpConstraint> &constraints)
{
assert(!pathName.empty());
assert(!ctrlName.empty());
auto result = processSimpleControllerParams(pathName, filtersAndMethods);
auto result = processSimpleControllerParams(pathName, constraints);
std::string path = std::move(result.lowerPath);
auto &item = wsCtrlMap_[path];
auto binder = std::make_shared<WebsocketControllerBinder>();
binder->handlerName_ = ctrlName;
binder->filterNames_ = result.filters;
binder->middlewareNames_ = result.middlewares;
drogon::app().getLoop()->queueInLoop([this, binder, ctrlName, path]() {
auto &object_ = DrClassMap::getSingleInstance(ctrlName);
auto controller =
@ -281,11 +280,11 @@ void HttpControllersRouter::addHttpRegex(
const std::string &regExp,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName)
{
auto binderInfo = std::make_shared<HttpControllerBinder>();
binderInfo->filterNames_ = filters;
binderInfo->middlewareNames_ = middlewareNames;
binderInfo->handlerName_ = handlerName;
binderInfo->binderPtr_ = binder;
drogon::app().getLoop()->queueInLoop([binderInfo]() {
@ -300,7 +299,7 @@ void HttpControllersRouter::addHttpPath(
const std::string &path,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName)
{
// Path is like /api/v1/service/method/{1}/{2}/xxx...
@ -513,7 +512,7 @@ void HttpControllersRouter::addHttpPath(
// Create new ControllerBinder
auto binderInfo = std::make_shared<HttpControllerBinder>();
binderInfo->filterNames_ = filters;
binderInfo->middlewareNames_ = middlewareNames;
binderInfo->handlerName_ = handlerName;
binderInfo->binderPtr_ = binder;
binderInfo->parameterPlaces_ = std::move(places);

View File

@ -45,20 +45,20 @@ class HttpControllersRouter : public trantor::NonCopyable
void registerHttpSimpleController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods);
const std::vector<internal::HttpConstraint> &constraints);
void registerWebSocketController(
const std::string &pathName,
const std::string &ctrlName,
const std::vector<internal::HttpConstraint> &filtersAndMethods);
const std::vector<internal::HttpConstraint> &constraints);
void addHttpPath(const std::string &path,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName = "");
void addHttpRegex(const std::string &regExp,
const internal::HttpBinderBasePtr &binder,
const std::vector<HttpMethod> &validMethods,
const std::vector<std::string> &filters,
const std::vector<std::string> &middlewareNames,
const std::string &handlerName = "");
RouteResult route(const HttpRequestImplPtr &req);
RouteResult routeWs(const HttpRequestImplPtr &req);

View File

@ -20,7 +20,7 @@
#include <memory>
#include <utility>
#include "AOPAdvice.h"
#include "FiltersFunction.h"
#include "MiddlewaresFunction.h"
#include "HttpAppFrameworkImpl.h"
#include "HttpConnectionLimit.h"
#include "HttpControllerBinder.h"
@ -418,7 +418,7 @@ void HttpServer::requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack)
aop.passPostRoutingObservers(req);
if (!aop.hasPostRoutingAdvices())
{
requestPassFilters(req, std::forward<Pack>(pack));
requestPassMiddlewares(req, std::forward<Pack>(pack));
return;
}
aop.passPostRoutingAdvices(req,
@ -430,35 +430,36 @@ void HttpServer::requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack)
}
else
{
requestPassFilters(req, std::move(pack));
requestPassMiddlewares(req,
std::move(pack));
}
});
}
template <typename Pack>
void HttpServer::requestPassFilters(const HttpRequestImplPtr &req, Pack &&pack)
void HttpServer::requestPassMiddlewares(const HttpRequestImplPtr &req,
Pack &&pack)
{
// pass filters
auto &filters = pack.binderPtr->filters_;
if (filters.empty())
// pass middlewares
auto &middlewares = pack.binderPtr->middlewares_;
if (middlewares.empty())
{
requestPreHandling(req, std::forward<Pack>(pack));
return;
}
filters_function::doFilters(filters,
req,
[req, pack = std::forward<Pack>(pack)](
const HttpResponsePtr &resp) mutable {
if (resp)
{
pack.callback(resp);
}
else
{
requestPreHandling(req,
std::move(pack));
}
});
auto callback = std::move(pack.callback);
pack.callback = nullptr;
middlewares_function::passMiddlewares(
middlewares,
req,
std::move(callback),
[req, pack = std::forward<Pack>(pack)](
std::function<void(const HttpResponsePtr &)>
&&middlewarePostCb) mutable {
pack.callback = std::move(middlewarePostCb);
requestPreHandling(req, std::forward<Pack>(pack));
});
}
template <typename Pack>

View File

@ -25,7 +25,7 @@ struct CallbackParamPack;
namespace drogon
{
class ControllerBinderBase;
struct ControllerBinderBase;
class HttpServer : trantor::NonCopyable
{
@ -107,7 +107,8 @@ class HttpServer : trantor::NonCopyable
template <typename Pack>
static void requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack);
template <typename Pack>
static void requestPassFilters(const HttpRequestImplPtr &req, Pack &&pack);
static void requestPassMiddlewares(const HttpRequestImplPtr &req,
Pack &&pack);
template <typename Pack>
static void requestPreHandling(const HttpRequestImplPtr &req, Pack &&pack);

View File

@ -0,0 +1,185 @@
/**
*
* @file MiddlewaresFunction.cc
* @author An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "MiddlewaresFunction.h"
#include "HttpRequestImpl.h"
#include "HttpAppFrameworkImpl.h"
#include <drogon/HttpMiddleware.h>
#include <queue>
namespace drogon
{
namespace middlewares_function
{
static void doFilterChains(
const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
size_t index,
const HttpRequestImplPtr &req,
std::shared_ptr<const std::function<void(const HttpResponsePtr &)>>
&&callbackPtr)
{
if (index < filters.size())
{
auto &filter = filters[index];
filter->doFilter(
req,
[/*copy*/ callbackPtr](const HttpResponsePtr &resp) {
(*callbackPtr)(resp);
},
[index, req, callbackPtr, &filters]() mutable {
auto ioLoop = req->getLoop();
if (ioLoop && !ioLoop->isInLoopThread())
{
ioLoop->queueInLoop(
[&filters,
index,
req,
callbackPtr = std::move(callbackPtr)]() mutable {
doFilterChains(filters,
index + 1,
req,
std::move(callbackPtr));
});
}
else
{
doFilterChains(filters,
index + 1,
req,
std::move(callbackPtr));
}
});
}
else
{
(*callbackPtr)(nullptr);
}
}
void doFilters(const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback)
{
auto callbackPtr =
std::make_shared<std::decay_t<decltype(callback)>>(std::move(callback));
doFilterChains(filters, 0, req, std::move(callbackPtr));
}
/**
* @brief
* The middlewares are invoked according to the onion ring model.
*
* @param outerCallback The road back to the outer layer of the onion ring.
* @param innermostHandler The innermost handler at the core of the onion ring.
*
* When going through each middleware, the `innermostHandler` is passed down as
* is, while the `outerCallback` is passed to the user code. User code wraps the
* outerCallback along with other post processing codes into `userPostCb`, and
* passes it to the next middleware.
*
* When reaching the onion core, the `innermostHandler` is finally called. It's
* parameter is a function that wraps the original `outerCallback` and all
* `userPostCb`s.
*/
static void passMiddlewareChains(
const std::vector<std::shared_ptr<HttpMiddlewareBase>> &middlewares,
size_t index,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&outerCallback,
std::function<void(std::function<void(const HttpResponsePtr &)> &&)>
&&innermostHandler)
{
if (index < middlewares.size())
{
auto &middleware = middlewares[index];
middleware->invoke(
req,
[index,
req,
innermostHandler = std::move(innermostHandler),
&middlewares](std::function<void(const HttpResponsePtr &)>
&&userPostCb) mutable {
// call next middleware
auto ioLoop = req->getLoop();
if (ioLoop && !ioLoop->isInLoopThread())
{
ioLoop->queueInLoop(
[&middlewares,
index,
req,
innermostHandler = std::move(innermostHandler),
userPostCb = std::move(userPostCb)]() mutable {
passMiddlewareChains(middlewares,
index + 1,
req,
std::move(userPostCb),
std::move(innermostHandler)
);
});
}
else
{
passMiddlewareChains(middlewares,
index + 1,
req,
std::move(userPostCb),
std::move(innermostHandler));
}
},
std::move(outerCallback));
}
else
{
innermostHandler(std::move(outerCallback));
}
}
std::vector<std::shared_ptr<HttpMiddlewareBase>> createMiddlewares(
const std::vector<std::string> &middlewareNames)
{
std::vector<std::shared_ptr<HttpMiddlewareBase>> middlewares;
for (const auto &name : middlewareNames)
{
auto object_ = DrClassMap::getSingleInstance(name);
if (auto middleware =
std::dynamic_pointer_cast<HttpMiddlewareBase>(object_))
{
middlewares.push_back(middleware);
}
else
{
LOG_ERROR << "middleware " << name << " not found";
}
}
return middlewares;
}
void passMiddlewares(
const std::vector<std::shared_ptr<HttpMiddlewareBase>> &middlewares,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&outermostCallback,
std::function<void(std::function<void(const HttpResponsePtr &)> &&)>
&&innermostHandler)
{
passMiddlewareChains(middlewares,
0,
req,
std::move(outermostCallback),
std::move(innermostHandler));
}
} // namespace middlewares_function
} // namespace drogon

View File

@ -0,0 +1,44 @@
/**
*
* MiddlewaresFunction.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include "impl_forwards.h"
#include <memory>
#include <string>
#include <vector>
namespace drogon
{
namespace middlewares_function
{
// We can not remove old filters api. GlobalFilter still needs it.
// GlobalFilter run filters in advice chains, which does not expose the outer
// response handler, so HttpMiddleware is not suitable for it.
void doFilters(const std::vector<std::shared_ptr<HttpFilterBase>> &filters,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
std::vector<std::shared_ptr<HttpMiddlewareBase>> createMiddlewares(
const std::vector<std::string> &middlewareNames);
void passMiddlewares(
const std::vector<std::shared_ptr<HttpMiddlewareBase>> &middlewares,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&outermostCallback,
std::function<void(std::function<void(const HttpResponsePtr &)> &&)>
&&innermostHandler);
} // namespace middlewares_function
} // namespace drogon

View File

@ -205,7 +205,7 @@ void StaticFileRouter::route(
}
}
if (location.filters_.empty())
if (location.middlewares_.empty())
{
sendStaticFileResponse(filePath,
req,
@ -215,27 +215,21 @@ void StaticFileRouter::route(
}
else
{
filters_function::doFilters(
location.filters_,
middlewares_function::passMiddlewares(
location.middlewares_,
req,
std::move(callback),
[this,
req,
filePath = std::move(filePath),
contentType =
std::string_view{location.defaultContentType_},
callback = std::move(callback)](
const HttpResponsePtr &resp) mutable {
if (resp)
{
callback(resp);
}
else
{
sendStaticFileResponse(filePath,
req,
std::move(callback),
contentType);
}
std::string_view{location.defaultContentType_}](
std::function<void(const HttpResponsePtr &)>
&&middlewarePostCb) mutable {
sendStaticFileResponse(filePath,
req,
std::move(middlewarePostCb),
contentType);
});
}
return;

View File

@ -15,7 +15,7 @@
#pragma once
#include "impl_forwards.h"
#include "FiltersFunction.h"
#include "MiddlewaresFunction.h"
#include <drogon/CacheMap.h>
#include <drogon/IOThreadStorage.h>
#include <functional>
@ -73,7 +73,7 @@ class StaticFileRouter
bool isCaseSensitive,
bool allowAll,
bool isRecursive,
const std::vector<std::string> &filters)
const std::vector<std::string> &middlewareNames)
{
locations_.emplace_back(uriPrefix,
defaultContentType,
@ -81,7 +81,7 @@ class StaticFileRouter
isCaseSensitive,
allowAll,
isRecursive,
filters);
middlewareNames);
}
void setStaticFileHeaders(
@ -165,7 +165,7 @@ class StaticFileRouter
bool isCaseSensitive_;
bool allowAll_;
bool isRecursive_;
std::vector<std::shared_ptr<drogon::HttpFilterBase>> filters_;
std::vector<std::shared_ptr<drogon::HttpMiddlewareBase>> middlewares_;
Location(const std::string &uriPrefix,
const std::string &defaultContentType,
@ -173,13 +173,13 @@ class StaticFileRouter
bool isCaseSensitive,
bool allowAll,
bool isRecursive,
const std::vector<std::string> &filters)
const std::vector<std::string> &middlewares)
: uriPrefix_(uriPrefix),
alias_(alias),
isCaseSensitive_(isCaseSensitive),
allowAll_(allowAll),
isRecursive_(isRecursive),
filters_(filters_function::createFilters(filters))
middlewares_(middlewares_function::createMiddlewares(middlewares))
{
if (!defaultContentType.empty())
{

View File

@ -17,6 +17,8 @@ class WebSocketControllerBase;
using WebSocketControllerBasePtr = std::shared_ptr<WebSocketControllerBase>;
class HttpFilterBase;
using HttpFilterBasePtr = std::shared_ptr<HttpFilterBase>;
class HttpMiddlewareBase;
using HttpMiddlewareBasePtr = std::shared_ptr<HttpMiddlewareBase>;
class HttpSimpleControllerBase;
using HttpSimpleControllerBasePtr = std::shared_ptr<HttpSimpleControllerBase>;
class HttpRequestImpl;

View File

@ -74,6 +74,7 @@ if (BUILD_CTL)
integration_test/server/MethodTest.cc
integration_test/server/RangeTestController.cc
integration_test/server/BeginAdviceTest.cc
integration_test/server/MiddlewareTest.cc
integration_test/server/main.cc)
if(DROGON_CXX_STANDARD GREATER_EQUAL 20 AND HAS_COROUTINE)

View File

@ -1046,6 +1046,25 @@ void doTest(const HttpClientPtr &client, std::shared_ptr<test::Case> TEST_CTX)
});
#endif
// Test middleware
req = HttpRequest::newHttpRequest();
req->setPath("/test-middleware");
client->sendRequest(req,
[TEST_CTX, req](ReqResult r,
const HttpResponsePtr &resp) {
REQUIRE(r == ReqResult::Ok);
CHECK(resp->body() == "123test321");
});
req = HttpRequest::newHttpRequest();
req->setPath("/test-middleware-block");
client->sendRequest(req,
[TEST_CTX, req](ReqResult r,
const HttpResponsePtr &resp) {
REQUIRE(r == ReqResult::Ok);
CHECK(resp->body() == "12block21");
});
#if defined(__cpp_impl_coroutine)
async_run([client, TEST_CTX]() -> Task<> {
// Test coroutine requests
@ -1142,6 +1161,19 @@ void doTest(const HttpClientPtr &client, std::shared_ptr<test::Case> TEST_CTX)
{
FAIL("Unexpected exception, what(): " + std::string(e.what()));
}
// Test coroutine middleware
try
{
auto req = HttpRequest::newHttpRequest();
req->setPath("/test-middleware-coro");
auto resp = co_await client->sendRequestCoro(req);
CHECK(resp->body() == "12corotestcoro21");
}
catch (const std::exception &e)
{
FAIL("Unexpected exception, what(): " + std::string(e.what()));
}
});
#endif
}

View File

@ -0,0 +1,163 @@
#include <drogon/HttpController.h>
#include <drogon/HttpMiddleware.h>
using namespace drogon;
class Middleware1 : public drogon::HttpMiddleware<Middleware1>
{
public:
Middleware1()
{
// do not omit constructor
void(0);
};
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) override
{
auto ptr = std::make_shared<std::string>("1");
req->attributes()->insert("test-middleware", ptr);
nextCb([req, ptr, mcb = std::move(mcb)](const HttpResponsePtr &resp) {
ptr->append("1");
resp->setBody(*ptr);
mcb(resp);
});
}
};
class Middleware2 : public drogon::HttpMiddleware<Middleware2>
{
public:
Middleware2()
{
// do not omit constructor
void(0);
};
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) override
{
auto ptr = req->attributes()->get<std::shared_ptr<std::string>>(
"test-middleware");
ptr->append("2");
nextCb([req, ptr, mcb = std::move(mcb)](const HttpResponsePtr &resp) {
ptr->append("2");
resp->setBody(*ptr);
mcb(resp);
});
}
};
class Middleware3 : public drogon::HttpMiddleware<Middleware3>
{
public:
Middleware3()
{
// do not omit constructor
void(0);
};
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) override
{
auto ptr = req->attributes()->get<std::shared_ptr<std::string>>(
"test-middleware");
ptr->append("3");
nextCb([req, ptr, mcb = std::move(mcb)](const HttpResponsePtr &resp) {
ptr->append("3");
resp->setBody(*ptr);
mcb(resp);
});
}
};
class MiddlewareBlock : public drogon::HttpMiddleware<MiddlewareBlock>
{
public:
MiddlewareBlock()
{
// do not omit constructor
void(0);
};
void invoke(const HttpRequestPtr &req,
MiddlewareNextCallback &&nextCb,
MiddlewareCallback &&mcb) override
{
auto ptr = req->attributes()->get<std::shared_ptr<std::string>>(
"test-middleware");
ptr->append("block");
mcb(HttpResponse::newHttpResponse());
}
};
#if defined(__cpp_impl_coroutine)
class MiddlewareCoro : public drogon::HttpCoroMiddleware<MiddlewareCoro>
{
public:
MiddlewareCoro()
{
// do not omit constructor
void(0);
};
Task<HttpResponsePtr> invoke(const HttpRequestPtr &req,
MiddlewareNextAwaiter &&nextAwaiter) override
{
auto ptr = req->attributes()->get<std::shared_ptr<std::string>>(
"test-middleware");
ptr->append("coro");
auto resp = co_await nextAwaiter;
ptr->append("coro");
resp->setBody(*ptr);
co_return resp;
}
};
#endif
class MiddlewareTest : public drogon::HttpController<MiddlewareTest>
{
public:
METHOD_LIST_BEGIN
ADD_METHOD_TO(MiddlewareTest::handleRequest,
"/test-middleware",
Get,
"Middleware1",
"Middleware2",
"Middleware3");
ADD_METHOD_TO(MiddlewareTest::handleRequest,
"/test-middleware-block",
Get,
"Middleware1",
"Middleware2",
"MiddlewareBlock",
"Middleware3");
#if defined(__cpp_impl_coroutine)
ADD_METHOD_TO(MiddlewareTest::handleRequest,
"/test-middleware-coro",
Get,
"Middleware1",
"Middleware2",
"MiddlewareCoro");
#endif
METHOD_LIST_END
void handleRequest(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
req->attributes()
->get<std::shared_ptr<std::string>>("test-middleware")
->append("test");
callback(HttpResponse::newHttpResponse());
}
};