Support request stream (#2055)

This commit is contained in:
Nitromelon 2024-07-03 11:31:39 +08:00 committed by GitHub
parent dfacd1b454
commit 5d4523a3a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1988 additions and 204 deletions

View File

@ -268,6 +268,7 @@ set(DROGON_SOURCES
lib/src/HttpFileUploadRequest.cc lib/src/HttpFileUploadRequest.cc
lib/src/HttpRequestImpl.cc lib/src/HttpRequestImpl.cc
lib/src/HttpRequestParser.cc lib/src/HttpRequestParser.cc
lib/src/RequestStream.cc
lib/src/HttpResponseImpl.cc lib/src/HttpResponseImpl.cc
lib/src/HttpResponseParser.cc lib/src/HttpResponseParser.cc
lib/src/HttpServer.cc lib/src/HttpServer.cc
@ -278,6 +279,7 @@ set(DROGON_SOURCES
lib/src/ListenerManager.cc lib/src/ListenerManager.cc
lib/src/LocalHostFilter.cc lib/src/LocalHostFilter.cc
lib/src/MultiPart.cc lib/src/MultiPart.cc
lib/src/MultipartStreamParser.cc
lib/src/NotFound.cc lib/src/NotFound.cc
lib/src/PluginsManager.cc lib/src/PluginsManager.cc
lib/src/PromExporter.cc lib/src/PromExporter.cc
@ -332,7 +334,8 @@ set(private_headers
lib/src/ConfigAdapterManager.h lib/src/ConfigAdapterManager.h
lib/src/JsonConfigAdapter.h lib/src/JsonConfigAdapter.h
lib/src/YamlConfigAdapter.h lib/src/YamlConfigAdapter.h
lib/src/ConfigAdapter.h) lib/src/ConfigAdapter.h
lib/src/MultipartStreamParser.h)
if (NOT WIN32) if (NOT WIN32)
set(DROGON_SOURCES set(DROGON_SOURCES
@ -559,6 +562,7 @@ set(DROGON_HEADERS
lib/inc/drogon/HttpFilter.h lib/inc/drogon/HttpFilter.h
lib/inc/drogon/HttpMiddleware.h lib/inc/drogon/HttpMiddleware.h
lib/inc/drogon/HttpRequest.h lib/inc/drogon/HttpRequest.h
lib/inc/drogon/RequestStream.h
lib/inc/drogon/HttpResponse.h lib/inc/drogon/HttpResponse.h
lib/inc/drogon/HttpSimpleController.h lib/inc/drogon/HttpSimpleController.h
lib/inc/drogon/HttpTypes.h lib/inc/drogon/HttpTypes.h

View File

@ -108,7 +108,7 @@
"session_timeout": 0, "session_timeout": 0,
//string value of SameSite attribute of the Set-Cookie HTTP response header //string value of SameSite attribute of the Set-Cookie HTTP response header
//valid value is either 'Null' (default), 'Lax', 'Strict' or 'None' //valid value is either 'Null' (default), 'Lax', 'Strict' or 'None'
"session_same_site" : "Null", "session_same_site": "Null",
//session_cookie_key: The cookie key of the session, "JSESSIONID" by default //session_cookie_key: The cookie key of the session, "JSESSIONID" by default
"session_cookie_key": "JSESSIONID", "session_cookie_key": "JSESSIONID",
//session_max_age: The max age of the session cookie, -1 by default //session_max_age: The max age of the session cookie, -1 by default
@ -310,7 +310,10 @@
// Currently only gzip and br are supported. Note: max_memory_body_size and max_body_size applies twice for compressed requests. // Currently only gzip and br are supported. Note: max_memory_body_size and max_body_size applies twice for compressed requests.
// Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request // Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request
// will be rejected. // will be rejected.
"enabled_compressed_request": false "enabled_compressed_request": false,
// enable_request_stream: Defaults to false. If true the server will enable stream mode for http requests.
// See the wiki for more details.
"enable_request_stream": false,
}, },
//plugins: Define all plugins running in the application //plugins: Define all plugins running in the application
"plugins": [ "plugins": [

View File

@ -283,6 +283,9 @@ app:
# Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request # Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request
# will be rejected. # will be rejected.
enabled_compressed_request: false enabled_compressed_request: false
# enable_request_stream: Defaults to false. If true the server will enable stream mode for http requests.
# See the wiki for more details.
enable_request_stream: false
# plugins: Define all plugins running in the application # plugins: Define all plugins running in the application
plugins: plugins:
# name: The class name of the plugin # name: The class name of the plugin

View File

@ -310,7 +310,10 @@
// Currently only gzip and br are supported. Note: max_memory_body_size and max_body_size applies twice for compressed requests. // Currently only gzip and br are supported. Note: max_memory_body_size and max_body_size applies twice for compressed requests.
// Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request // Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request
// will be rejected. // will be rejected.
"enabled_compressed_request": false "enabled_compressed_request": false,
// enable_request_stream: Defaults to false. If true the server will enable stream mode for http requests.
// See the wiki for more details.
"enable_request_stream": false,
}, },
//plugins: Define all plugins running in the application //plugins: Define all plugins running in the application
"plugins": [ "plugins": [

View File

@ -283,6 +283,9 @@ app:
# Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request # Once when receiving and once when decompressing. i.e. if the decompressed body is larger than max_body_size, the request
# will be rejected. # will be rejected.
enabled_compressed_request: false enabled_compressed_request: false
# enable_request_stream: Defaults to false. If true the server will enable stream mode for http requests.
# See the wiki for more details.
enable_request_stream: false
# plugins: Define all plugins running in the application # plugins: Define all plugins running in the application
plugins: plugins:
# name: The class name of the plugin # name: The class name of the plugin

View File

@ -31,7 +31,8 @@ add_executable(redis_simple redis/main.cc
add_executable(redis_chat redis_chat/main.cc add_executable(redis_chat redis_chat/main.cc
redis_chat/controllers/Chat.cc) redis_chat/controllers/Chat.cc)
add_executable(async_stream async_stream/main.cc) add_executable(async_stream async_stream/main.cc
async_stream/RequestStreamExampleCtrl.cc)
set(example_targets set(example_targets
benchmark benchmark

View File

@ -0,0 +1,167 @@
#include <drogon/drogon.h>
#include <drogon/HttpController.h>
#include <drogon/HttpRequest.h>
#include <fstream>
using namespace drogon;
class StreamEchoReader : public RequestStreamReader
{
public:
StreamEchoReader(ResponseStreamPtr respStream)
: respStream_(std::move(respStream))
{
}
void onStreamData(const char *data, size_t length) override
{
LOG_INFO << "onStreamData[" << length << "]";
respStream_->send({data, length});
}
void onStreamFinish(std::exception_ptr ptr) override
{
if (ptr)
{
try
{
std::rethrow_exception(ptr);
}
catch (const std::exception &e)
{
LOG_ERROR << "onStreamError: " << e.what();
}
}
else
{
LOG_INFO << "onStreamFinish";
}
respStream_->close();
}
private:
ResponseStreamPtr respStream_;
};
class RequestStreamExampleCtrl : public HttpController<RequestStreamExampleCtrl>
{
public:
METHOD_LIST_BEGIN
ADD_METHOD_TO(RequestStreamExampleCtrl::stream_echo, "/stream_echo", Post);
ADD_METHOD_TO(RequestStreamExampleCtrl::stream_upload,
"/stream_upload",
Post);
METHOD_LIST_END
void stream_echo(
const HttpRequestPtr &,
RequestStreamPtr &&stream,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
auto resp = drogon::HttpResponse::newAsyncStreamResponse(
[stream](ResponseStreamPtr respStream) {
stream->setStreamReader(
std::make_shared<StreamEchoReader>(std::move(respStream)));
});
callback(resp);
}
void stream_upload(
const HttpRequestPtr &req,
RequestStreamPtr &&stream,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
struct Entry
{
MultipartHeader header;
std::string tmpName;
std::ofstream file;
};
auto files = std::make_shared<std::vector<Entry>>();
auto reader = RequestStreamReader::newMultipartReader(
req,
[files](MultipartHeader &&header) {
LOG_INFO << "Multipart name: " << header.name
<< ", filename:" << header.filename
<< ", contentType:" << header.contentType;
files->push_back({std::move(header)});
auto tmpName = drogon::utils::genRandomString(40);
if (!files->back().header.filename.empty())
{
files->back().tmpName = tmpName;
files->back().file.open("uploads/" + tmpName,
std::ios::trunc);
}
},
[files](const char *data, size_t length) {
if (files->back().tmpName.empty())
{
return;
}
auto &currentFile = files->back().file;
if (length == 0)
{
LOG_INFO << "file finish";
if (currentFile.is_open())
{
currentFile.flush();
currentFile.close();
}
return;
}
LOG_INFO << "data[" << length << "]: ";
if (currentFile.is_open())
{
LOG_INFO << "write file";
currentFile.write(data, length);
}
else
{
LOG_ERROR << "file not open";
}
},
[files, callback = std::move(callback)](std::exception_ptr ex) {
if (ex)
{
try
{
std::rethrow_exception(std::move(ex));
}
catch (const StreamError &e)
{
LOG_ERROR << "stream error: " << e.what();
}
catch (const std::exception &e)
{
LOG_ERROR << "multipart error: " << e.what();
}
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k400BadRequest);
resp->setBody("error\n");
callback(resp);
}
else
{
LOG_INFO << "stream finish, received " << files->size()
<< " files";
Json::Value respJson;
for (const auto &item : *files)
{
if (item.tmpName.empty())
continue;
Json::Value entry;
entry["name"] = item.header.name;
entry["filename"] = item.header.filename;
entry["tmpName"] = item.tmpName;
respJson.append(entry);
}
auto resp = HttpResponse::newHttpJsonResponse(respJson);
callback(resp);
}
});
stream->setStreamReader(std::move(reader));
}
};

View File

@ -1,6 +1,6 @@
#include <drogon/drogon.h> #include <drogon/drogon.h>
#include <chrono> #include <chrono>
#include <memory>
using namespace drogon; using namespace drogon;
using namespace std::chrono_literals; using namespace std::chrono_literals;
@ -28,6 +28,56 @@ int main()
callback(resp); callback(resp);
}); });
// Example: register a stream-mode function handler
app().registerHandler(
"/stream_req",
[](const HttpRequestPtr &req,
RequestStreamPtr &&stream,
std::function<void(const HttpResponsePtr &)> &&callback) {
if (!stream)
{
LOG_INFO << "stream mode is not enabled";
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k400BadRequest);
resp->setBody("no stream");
callback(resp);
return;
}
auto reader = RequestStreamReader::newReader(
[](const char *data, size_t length) {
LOG_INFO << "piece[" << length
<< "]: " << std::string_view{data, length};
},
[callback = std::move(callback)](std::exception_ptr ex) {
auto resp = HttpResponse::newHttpResponse();
if (ex)
{
try
{
std::rethrow_exception(std::move(ex));
}
catch (const std::exception &e)
{
LOG_ERROR << "stream error: " << e.what();
}
resp->setStatusCode(k400BadRequest);
resp->setBody("error\n");
callback(resp);
}
else
{
LOG_INFO << "stream finish";
resp->setBody("success\n");
callback(resp);
}
});
stream->setStreamReader(std::move(reader));
},
{Post});
LOG_INFO << "Server running on 127.0.0.1:8848"; LOG_INFO << "Server running on 127.0.0.1:8848";
app().enableRequestStream(); // This is for request stream.
app().addListener("127.0.0.1", 8848).run(); app().addListener("127.0.0.1", 8848).run();
} }

View File

@ -1606,6 +1606,9 @@ class DROGON_EXPORT HttpAppFramework : public trantor::NonCopyable
virtual HttpAppFramework &setAfterAcceptSockOptCallback( virtual HttpAppFramework &setAfterAcceptSockOptCallback(
std::function<void(int)> cb) = 0; std::function<void(int)> cb) = 0;
virtual HttpAppFramework &enableRequestStream(bool enable = true) = 0;
virtual bool isRequestStreamEnabled() const = 0;
private: private:
virtual void registerHttpController( virtual void registerHttpController(
const std::string &pathPattern, const std::string &pathPattern,

View File

@ -164,6 +164,7 @@ class HttpBinderBase
std::function<void(const HttpResponsePtr &)> &&callback) = 0; std::function<void(const HttpResponsePtr &)> &&callback) = 0;
virtual size_t paramCount() = 0; virtual size_t paramCount() = 0;
virtual const std::string &handlerName() const = 0; virtual const std::string &handlerName() const = 0;
virtual bool isStreamHandler() = 0;
virtual ~HttpBinderBase() virtual ~HttpBinderBase()
{ {
@ -218,6 +219,11 @@ class HttpBinder : public HttpBinderBase
return traits::arity; return traits::arity;
} }
bool isStreamHandler() override
{
return traits::isStreamHandler;
}
HttpBinder(FUNCTION &&func) : func_(std::forward<FUNCTION>(func)) HttpBinder(FUNCTION &&func) : func_(std::forward<FUNCTION>(func))
{ {
static_assert(traits::isHTTPFunction, static_assert(traits::isHTTPFunction,
@ -266,6 +272,7 @@ class HttpBinder : public HttpBinderBase
template <typename... Values, template <typename... Values,
std::size_t Boundary = argument_count, std::size_t Boundary = argument_count,
bool isStreamHandler = traits::isStreamHandler,
bool isCoroutine = traits::isCoroutine> bool isCoroutine = traits::isCoroutine>
void run(std::deque<std::string> &pathArguments, void run(std::deque<std::string> &pathArguments,
const HttpRequestPtr &req, const HttpRequestPtr &req,
@ -344,7 +351,17 @@ class HttpBinder : public HttpBinderBase
{ {
// Explicit copy because `callFunction` moves it // Explicit copy because `callFunction` moves it
auto cb = callback; auto cb = callback;
callFunction(req, cb, std::move(values)...); if constexpr (isStreamHandler)
{
callFunction(req,
createRequestStream(req),
cb,
std::move(values)...);
}
else
{
callFunction(req, cb, std::move(values)...);
}
} }
catch (const std::exception &except) catch (const std::exception &except)
{ {
@ -359,6 +376,7 @@ class HttpBinder : public HttpBinderBase
#ifdef __cpp_impl_coroutine #ifdef __cpp_impl_coroutine
else else
{ {
static_assert(!isStreamHandler);
[this](HttpRequestPtr req, [this](HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback, std::function<void(const HttpResponsePtr &)> callback,
Values &&...values) -> AsyncTask { Values &&...values) -> AsyncTask {

View File

@ -179,6 +179,17 @@ class DROGON_EXPORT HttpRequest
return cookies(); return cookies();
} }
/**
* @brief Return content length parsed from the Content-Length header
* If no Content-Length header, return null.
*/
virtual size_t realContentLength() const = 0;
size_t getRealContentLength() const
{
return realContentLength();
}
/// Get the query string of the request. /// Get the query string of the request.
/** /**
* The query string is the substring after the '?' in the URL string. * The query string is the substring after the '?' in the URL string.

View File

@ -0,0 +1,116 @@
/**
*
* @file RequestStream.h
* @author Nitromelon
*
* Copyright 2024, Nitromelon. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/exports.h>
#include <string>
#include <functional>
#include <memory>
namespace drogon
{
class HttpRequest;
using HttpRequestPtr = std::shared_ptr<HttpRequest>;
class RequestStreamReader;
using RequestStreamReaderPtr = std::shared_ptr<RequestStreamReader>;
struct MultipartHeader
{
std::string name;
std::string filename;
std::string contentType;
};
class DROGON_EXPORT RequestStream
{
public:
virtual ~RequestStream() = default;
virtual void setStreamReader(RequestStreamReaderPtr reader) = 0;
};
using RequestStreamPtr = std::shared_ptr<RequestStream>;
namespace internal
{
DROGON_EXPORT RequestStreamPtr createRequestStream(const HttpRequestPtr &req);
}
enum class StreamErrorCode
{
kNone = 0,
kBadRequest,
kConnectionBroken
};
class StreamError final : public std::exception
{
public:
const char *what() const noexcept override
{
return message_.data();
}
StreamErrorCode code() const
{
return code_;
}
StreamError(StreamErrorCode code, const std::string &message)
: message_(message), code_(code)
{
}
StreamError(StreamErrorCode code, std::string &&message)
: message_(std::move(message)), code_(code)
{
}
StreamError() = delete;
private:
std::string message_;
StreamErrorCode code_;
};
/**
* An interface for stream request reading.
* User should create an implementation class, or use built-in handlers
*/
class RequestStreamReader
{
public:
virtual ~RequestStreamReader() = default;
virtual void onStreamData(const char *, size_t) = 0;
virtual void onStreamFinish(std::exception_ptr) = 0;
using StreamDataCallback = std::function<void(const char *, size_t)>;
using StreamFinishCallback = std::function<void(std::exception_ptr)>;
// Create a handler with default implementation
static RequestStreamReaderPtr newReader(StreamDataCallback dataCb,
StreamFinishCallback finishCb);
// A handler that drops all data
static RequestStreamReaderPtr newNullReader();
using MultipartHeaderCallback = std::function<void(MultipartHeader header)>;
static RequestStreamReaderPtr newMultipartReader(
const HttpRequestPtr &req,
MultipartHeaderCallback headerCb,
StreamDataCallback dataCb,
StreamFinishCallback finishCb);
};
} // namespace drogon

View File

@ -15,6 +15,7 @@
#pragma once #pragma once
#include <drogon/DrObject.h> #include <drogon/DrObject.h>
#include <drogon/RequestStream.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <tuple> #include <tuple>
@ -46,53 +47,36 @@ struct resumable_type : std::false_type
template <typename> template <typename>
struct FunctionTraits; struct FunctionTraits;
// functor,lambda,std::function... //
template <typename Function> // Basic match, inherited by all other matches
struct FunctionTraits //
: public FunctionTraits< template <typename ReturnType, typename... Arguments>
decltype(&std::remove_reference_t<Function>::operator())> struct FunctionTraits<ReturnType (*)(Arguments...)>
{ {
static const bool isClassFunction = false; using result_type = ReturnType;
static const bool isDrObjectClass = false;
template <std::size_t Index>
using argument =
typename std::tuple_element_t<Index, std::tuple<Arguments...>>;
static const std::size_t arity = sizeof...(Arguments);
using class_type = void; using class_type = void;
using return_type = ReturnType;
static const bool isHTTPFunction = false;
static const bool isClassFunction = false;
static const bool isStreamHandler = false;
static const bool isDrObjectClass = false;
static const bool isCoroutine = false;
static const std::string name() static const std::string name()
{ {
return std::string("Functor"); return std::string("Normal or Static Function");
} }
}; };
// class instance method of const object //
template <typename ClassType, typename ReturnType, typename... Arguments> // Match normal functions
struct FunctionTraits<ReturnType (ClassType::*)(Arguments...) const> //
: FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isClassFunction = true;
static const bool isDrObjectClass =
std::is_base_of<DrObject<ClassType>, ClassType>::value;
using class_type = ClassType;
static const std::string name()
{
return std::string("Class Function");
}
};
// class instance method of non-const object
template <typename ClassType, typename ReturnType, typename... Arguments>
struct FunctionTraits<ReturnType (ClassType::*)(Arguments...)>
: FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isClassFunction = true;
static const bool isDrObjectClass =
std::is_base_of<DrObject<ClassType>, ClassType>::value;
using class_type = ClassType;
static const std::string name()
{
return std::string("Class Function");
}
};
// normal function for HTTP handling // normal function for HTTP handling
template <typename ReturnType, typename... Arguments> template <typename ReturnType, typename... Arguments>
@ -108,16 +92,93 @@ struct FunctionTraits<
using return_type = ReturnType; using return_type = ReturnType;
}; };
template <typename ReturnType, typename... Arguments> // normal function with custom request object
template <typename T, typename ReturnType, typename... Arguments>
struct FunctionTraits< struct FunctionTraits<
ReturnType (*)(HttpRequestPtr &req, ReturnType (*)(T &&customReq,
std::function<void(const HttpResponsePtr &)> &&callback, std::function<void(const HttpResponsePtr &)> &&callback,
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)> Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{ {
static const bool isHTTPFunction = false; static const bool isHTTPFunction = !resumable_type<ReturnType>::value;
static const bool isCoroutine = false;
using class_type = void; using class_type = void;
using first_param_type = T;
using return_type = ReturnType;
}; };
// normal function with stream handler
template <typename ReturnType, typename... Arguments>
struct FunctionTraits<
ReturnType (*)(const HttpRequestPtr &req,
RequestStreamPtr &&streamCtx,
std::function<void(const HttpResponsePtr &)> &&callback,
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isHTTPFunction = !resumable_type<ReturnType>::value;
static const bool isCoroutine = false;
static const bool isStreamHandler = true;
using class_type = void;
using first_param_type = HttpRequestPtr;
using return_type = ReturnType;
};
//
// Match functor,lambda,std::function... inherits normal function matches
//
template <typename Function>
struct FunctionTraits
: public FunctionTraits<
decltype(&std::remove_reference_t<Function>::operator())>
{
static const bool isClassFunction = false;
static const bool isDrObjectClass = false;
using class_type = void;
static const std::string name()
{
return std::string("Functor");
}
};
//
// Match class functions, inherits normal function matches
//
// class const method
template <typename ClassType, typename ReturnType, typename... Arguments>
struct FunctionTraits<ReturnType (ClassType::*)(Arguments...) const>
: FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isClassFunction = true;
static const bool isDrObjectClass =
std::is_base_of<DrObject<ClassType>, ClassType>::value;
using class_type = ClassType;
static const std::string name()
{
return std::string("Class Function");
}
};
// class non-const method
template <typename ClassType, typename ReturnType, typename... Arguments>
struct FunctionTraits<ReturnType (ClassType::*)(Arguments...)>
: FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isClassFunction = true;
static const bool isDrObjectClass =
std::is_base_of<DrObject<ClassType>, ClassType>::value;
using class_type = ClassType;
static const std::string name()
{
return std::string("Class Function");
}
};
//
// Match coroutine functions
//
#ifdef __cpp_impl_coroutine #ifdef __cpp_impl_coroutine
template <typename... Arguments> template <typename... Arguments>
struct FunctionTraits< struct FunctionTraits<
@ -158,6 +219,20 @@ struct FunctionTraits<Task<HttpResponsePtr> (*)(HttpRequestPtr req,
}; };
#endif #endif
//
// Bad matches
//
template <typename ReturnType, typename... Arguments>
struct FunctionTraits<
ReturnType (*)(HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isHTTPFunction = false;
using class_type = void;
};
template <typename ReturnType, typename... Arguments> template <typename ReturnType, typename... Arguments>
struct FunctionTraits< struct FunctionTraits<
ReturnType (*)(HttpRequestPtr &&req, ReturnType (*)(HttpRequestPtr &&req,
@ -168,43 +243,5 @@ struct FunctionTraits<
using class_type = void; using class_type = void;
}; };
// normal function for HTTP handling
template <typename T, typename ReturnType, typename... Arguments>
struct FunctionTraits<
ReturnType (*)(T &&customReq,
std::function<void(const HttpResponsePtr &)> &&callback,
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isHTTPFunction = !resumable_type<ReturnType>::value;
static const bool isCoroutine = false;
using class_type = void;
using first_param_type = T;
using return_type = ReturnType;
};
// normal function
template <typename ReturnType, typename... Arguments>
struct FunctionTraits<ReturnType (*)(Arguments...)>
{
using result_type = ReturnType;
template <std::size_t Index>
using argument =
typename std::tuple_element_t<Index, std::tuple<Arguments...>>;
static const std::size_t arity = sizeof...(Arguments);
using class_type = void;
using return_type = ReturnType;
static const bool isHTTPFunction = false;
static const bool isClassFunction = false;
static const bool isDrObjectClass = false;
static const bool isCoroutine = false;
static const std::string name()
{
return std::string("Normal or Static Function");
}
};
} // namespace internal } // namespace internal
} // namespace drogon } // namespace drogon

View File

@ -524,6 +524,9 @@ static void loadApp(const Json::Value &app)
bool enableCompressedRequests = bool enableCompressedRequests =
app.get("enabled_compressed_request", false).asBool(); app.get("enabled_compressed_request", false).asBool();
drogon::app().enableCompressedRequest(enableCompressedRequests); drogon::app().enableCompressedRequest(enableCompressedRequests);
drogon::app().enableRequestStream(
app.get("enable_request_stream", false).asBool());
} }
static void loadDbClients(const Json::Value &dbClients) static void loadDbClients(const Json::Value &dbClients)

View File

@ -41,6 +41,11 @@ struct ControllerBinderBase
virtual void handleRequest( virtual void handleRequest(
const HttpRequestImplPtr &req, const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) const = 0; std::function<void(const HttpResponsePtr &)> &&callback) const = 0;
virtual bool isStreamHandler() const
{
return false;
}
}; };
struct RouteResult struct RouteResult

View File

@ -1246,6 +1246,17 @@ int64_t HttpAppFrameworkImpl::getConnectionCount() const
return HttpConnectionLimit::instance().getConnectionNum(); return HttpConnectionLimit::instance().getConnectionNum();
} }
HttpAppFramework &HttpAppFrameworkImpl::enableRequestStream(bool enable)
{
enableRequestStream_ = enable;
return *this;
}
bool HttpAppFrameworkImpl::isRequestStreamEnabled() const
{
return enableRequestStream_;
}
// AOP registration methods // AOP registration methods
HttpAppFramework &HttpAppFrameworkImpl::registerNewConnectionAdvice( HttpAppFramework &HttpAppFrameworkImpl::registerNewConnectionAdvice(

View File

@ -663,6 +663,9 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
HttpAppFramework &setAfterAcceptSockOptCallback( HttpAppFramework &setAfterAcceptSockOptCallback(
std::function<void(int)> cb) override; std::function<void(int)> cb) override;
HttpAppFramework &enableRequestStream(bool enable) override;
bool isRequestStreamEnabled() const override;
private: private:
void registerHttpController(const std::string &pathPattern, void registerHttpController(const std::string &pathPattern,
const internal::HttpBinderBasePtr &binder, const internal::HttpBinderBasePtr &binder,
@ -753,6 +756,8 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
ExceptionHandler exceptionHandler_{defaultExceptionHandler}; ExceptionHandler exceptionHandler_{defaultExceptionHandler};
bool enableCompressedRequest_{false}; bool enableCompressedRequest_{false};
bool enableRequestStream_{false};
}; };
} // namespace drogon } // namespace drogon

View File

@ -39,6 +39,12 @@ class HttpControllerBinder : public ControllerBinderBase
const HttpRequestImplPtr &req, const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) const override; std::function<void(const HttpResponsePtr &)> &&callback) const override;
bool isStreamHandler() const override
{
assert(binderPtr_);
return binderPtr_->isStreamHandler();
}
internal::HttpBinderBasePtr binderPtr_; internal::HttpBinderBasePtr binderPtr_;
std::vector<size_t> parameterPlaces_; std::vector<size_t> parameterPlaces_;
std::vector<std::pair<std::string, size_t>> queryParametersPlaces_; std::vector<std::pair<std::string, size_t>> queryParametersPlaces_;

View File

@ -567,6 +567,8 @@ void HttpRequestImpl::swap(HttpRequestImpl &that) noexcept
swap(query_, that.query_); swap(query_, that.query_);
swap(headers_, that.headers_); swap(headers_, that.headers_);
swap(cookies_, that.cookies_); swap(cookies_, that.cookies_);
swap(contentLengthHeaderValue_, that.contentLengthHeaderValue_);
swap(realContentLength_, that.realContentLength_);
swap(parameters_, that.parameters_); swap(parameters_, that.parameters_);
swap(jsonPtr_, that.jsonPtr_); swap(jsonPtr_, that.jsonPtr_);
swap(sessionPtr_, that.sessionPtr_); swap(sessionPtr_, that.sessionPtr_);
@ -584,6 +586,12 @@ void HttpRequestImpl::swap(HttpRequestImpl &that) noexcept
swap(flagForParsingContentType_, that.flagForParsingContentType_); swap(flagForParsingContentType_, that.flagForParsingContentType_);
swap(jsonParsingErrorPtr_, that.jsonParsingErrorPtr_); swap(jsonParsingErrorPtr_, that.jsonParsingErrorPtr_);
swap(routingParams_, that.routingParams_); swap(routingParams_, that.routingParams_);
// stream
swap(streamStatus_, that.streamStatus_);
swap(streamReaderPtr_, that.streamReaderPtr_);
swap(streamFinishCb_, that.streamFinishCb_);
swap(streamExceptionPtr_, that.streamExceptionPtr_);
swap(startProcessing_, that.startProcessing_);
} }
const char *HttpRequestImpl::versionString() const const char *HttpRequestImpl::versionString() const
@ -723,6 +731,11 @@ HttpRequestImpl::~HttpRequestImpl()
void HttpRequestImpl::reserveBodySize(size_t length) void HttpRequestImpl::reserveBodySize(size_t length)
{ {
assert(loop_->isInLoopThread());
if (cacheFilePtr_)
{
return;
}
if (length <= HttpAppFrameworkImpl::instance().getClientMaxMemoryBodySize()) if (length <= HttpAppFrameworkImpl::instance().getClientMaxMemoryBodySize())
{ {
content_.reserve(length); content_.reserve(length);
@ -736,7 +749,14 @@ void HttpRequestImpl::reserveBodySize(size_t length)
void HttpRequestImpl::appendToBody(const char *data, size_t length) void HttpRequestImpl::appendToBody(const char *data, size_t length)
{ {
if (cacheFilePtr_) assert(loop_->isInLoopThread());
realContentLength_ += length;
if (streamReaderPtr_)
{
assert(streamStatus_ == ReqStreamStatus::Open);
streamReaderPtr_->onStreamData(data, length);
}
else if (cacheFilePtr_)
{ {
cacheFilePtr_->append(data, length); cacheFilePtr_->append(data, length);
} }
@ -974,3 +994,114 @@ StreamDecompressStatus HttpRequestImpl::decompressBodyGzip() noexcept
} }
return status; return status;
} }
void HttpRequestImpl::setStreamReader(RequestStreamReaderPtr reader)
{
assert(loop_->isInLoopThread());
assert(!streamReaderPtr_);
assert(streamStatus_ > ReqStreamStatus::None);
if (streamExceptionPtr_)
{
assert(streamStatus_ == ReqStreamStatus::Error);
reader->onStreamFinish(std::move(streamExceptionPtr_));
streamExceptionPtr_ = nullptr;
return;
}
// Consume already received body
if (cacheFilePtr_)
{
auto bodyPieceView = cacheFilePtr_->getStringView();
if (!bodyPieceView.empty())
reader->onStreamData(bodyPieceView.data(), bodyPieceView.length());
cacheFilePtr_.reset();
}
else if (!content_.empty())
{
reader->onStreamData(content_.data(), content_.length());
content_.clear();
}
if (streamStatus_ == ReqStreamStatus::Finish)
{
reader->onStreamFinish({});
}
else
{
streamReaderPtr_ = std::move(reader);
}
}
void HttpRequestImpl::streamStart()
{
assert(streamStatus_ == ReqStreamStatus::None);
streamStatus_ = ReqStreamStatus::Open;
}
void HttpRequestImpl::streamFinish()
{
assert(loop_->isInLoopThread());
assert(streamStatus_ == ReqStreamStatus::Open);
streamStatus_ = ReqStreamStatus::Finish;
if (streamFinishCb_)
{
auto cb = std::move(streamFinishCb_);
streamFinishCb_ = nullptr;
cb();
}
if (streamReaderPtr_)
{
streamReaderPtr_->onStreamFinish({});
streamReaderPtr_ = nullptr;
}
}
void HttpRequestImpl::streamError(std::exception_ptr ex)
{
// TODO: can we be sure that streamError() only be called once?
// If not, we could allow it to be called multiple times, and
// only handle the first one.
assert(loop_->isInLoopThread());
assert(streamStatus_ == ReqStreamStatus::Open);
streamStatus_ = ReqStreamStatus::Error;
if (streamReaderPtr_)
{
streamReaderPtr_->onStreamFinish(std::move(ex));
streamReaderPtr_ = nullptr;
}
else
{
streamExceptionPtr_ = std::move(ex);
}
if (streamFinishCb_)
{
auto cb = std::move(streamFinishCb_);
streamFinishCb_ = nullptr;
cb();
}
}
void HttpRequestImpl::waitForStreamFinish(std::function<void()> &&cb)
{
assert(loop_->isInLoopThread());
assert(streamStatus_ > ReqStreamStatus::None);
if (streamStatus_ <= ReqStreamStatus::Open)
{
assert(!streamFinishCb_); // should only be called once
streamFinishCb_ = std::move(cb);
}
else
{
cb();
}
}
void HttpRequestImpl::quitStreamMode()
{
assert(loop_->isInLoopThread());
assert(streamStatus_ >= ReqStreamStatus::Finish);
assert(!streamReaderPtr_);
streamStatus_ = ReqStreamStatus::None;
}

View File

@ -18,6 +18,7 @@
#include "CacheFile.h" #include "CacheFile.h"
#include <drogon/utils/Utilities.h> #include <drogon/utils/Utilities.h>
#include <drogon/HttpRequest.h> #include <drogon/HttpRequest.h>
#include <drogon/RequestStream.h>
#include <drogon/utils/Utilities.h> #include <drogon/utils/Utilities.h>
#include <trantor/net/EventLoop.h> #include <trantor/net/EventLoop.h>
#include <trantor/net/InetAddress.h> #include <trantor/net/InetAddress.h>
@ -42,6 +43,14 @@ enum class StreamDecompressStatus
Ok Ok
}; };
enum class ReqStreamStatus
{
None = 0,
Open = 1,
Finish = 2,
Error = 3
};
class HttpRequestImpl : public HttpRequest class HttpRequestImpl : public HttpRequest
{ {
public: public:
@ -60,6 +69,8 @@ class HttpRequestImpl : public HttpRequest
flagForParsingJson_ = false; flagForParsingJson_ = false;
headers_.clear(); headers_.clear();
cookies_.clear(); cookies_.clear();
contentLengthHeaderValue_.reset();
realContentLength_ = 0;
flagForParsingParameters_ = false; flagForParsingParameters_ = false;
path_.clear(); path_.clear();
originalPath_.clear(); originalPath_.clear();
@ -80,6 +91,12 @@ class HttpRequestImpl : public HttpRequest
jsonParsingErrorPtr_.reset(); jsonParsingErrorPtr_.reset();
peerCertificate_.reset(); peerCertificate_.reset();
routingParams_.clear(); routingParams_.clear();
// stream
streamStatus_ = ReqStreamStatus::None;
streamReaderPtr_.reset();
streamFinishCb_ = nullptr;
streamExceptionPtr_ = nullptr;
startProcessing_ = false;
} }
trantor::EventLoop *getLoop() trantor::EventLoop *getLoop()
@ -207,6 +224,10 @@ class HttpRequestImpl : public HttpRequest
std::string_view bodyView() const std::string_view bodyView() const
{ {
if (isStreamMode())
{
return emptySv_;
}
if (cacheFilePtr_) if (cacheFilePtr_)
{ {
return cacheFilePtr_->getStringView(); return cacheFilePtr_->getStringView();
@ -216,6 +237,10 @@ class HttpRequestImpl : public HttpRequest
const char *bodyData() const override const char *bodyData() const override
{ {
if (isStreamMode())
{
return emptySv_.data();
}
if (cacheFilePtr_) if (cacheFilePtr_)
{ {
return cacheFilePtr_->getStringView().data(); return cacheFilePtr_->getStringView().data();
@ -225,6 +250,10 @@ class HttpRequestImpl : public HttpRequest
size_t bodyLength() const override size_t bodyLength() const override
{ {
if (isStreamMode())
{
return emptySv_.length();
}
if (cacheFilePtr_) if (cacheFilePtr_)
{ {
return cacheFilePtr_->getStringView().length(); return cacheFilePtr_->getStringView().length();
@ -243,6 +272,10 @@ class HttpRequestImpl : public HttpRequest
std::string_view contentView() const std::string_view contentView() const
{ {
if (isStreamMode())
{
return emptySv_;
}
if (cacheFilePtr_) if (cacheFilePtr_)
return cacheFilePtr_->getStringView(); return cacheFilePtr_->getStringView();
return content_; return content_;
@ -349,6 +382,16 @@ class HttpRequestImpl : public HttpRequest
return cookies_; return cookies_;
} }
std::optional<size_t> getContentLengthHeaderValue() const
{
return contentLengthHeaderValue_;
}
size_t realContentLength() const override
{
return realContentLength_;
}
void setParameter(const std::string &key, const std::string &value) override void setParameter(const std::string &key, const std::string &value) override
{ {
flagForParsingParameters_ = true; flagForParsingParameters_ = true;
@ -526,7 +569,36 @@ class HttpRequestImpl : public HttpRequest
StreamDecompressStatus decompressBody(); StreamDecompressStatus decompressBody();
~HttpRequestImpl(); // Stream mode api
ReqStreamStatus streamStatus() const
{
return streamStatus_;
}
bool isStreamMode() const
{
return streamStatus_ > ReqStreamStatus::None;
}
void streamStart();
void streamFinish();
void streamError(std::exception_ptr ex);
void setStreamReader(RequestStreamReaderPtr reader);
void waitForStreamFinish(std::function<void()> &&cb);
void quitStreamMode();
void startProcessing()
{
startProcessing_ = true;
}
bool isProcessingStarted() const
{
return startProcessing_;
}
~HttpRequestImpl() override;
protected: protected:
friend class HttpRequest; friend class HttpRequest;
@ -592,6 +664,9 @@ class HttpRequestImpl : public HttpRequest
StreamDecompressStatus decompressBodyBrotli() noexcept; StreamDecompressStatus decompressBodyBrotli() noexcept;
#endif #endif
StreamDecompressStatus decompressBodyGzip() noexcept; StreamDecompressStatus decompressBodyGzip() noexcept;
static constexpr const std::string_view emptySv_{""};
mutable bool flagForParsingParameters_{false}; mutable bool flagForParsingParameters_{false};
mutable bool flagForParsingJson_{false}; mutable bool flagForParsingJson_{false};
HttpMethod method_{Invalid}; HttpMethod method_{Invalid};
@ -604,6 +679,8 @@ class HttpRequestImpl : public HttpRequest
std::string query_; std::string query_;
SafeStringMap<std::string> headers_; SafeStringMap<std::string> headers_;
SafeStringMap<std::string> cookies_; SafeStringMap<std::string> cookies_;
std::optional<size_t> contentLengthHeaderValue_;
size_t realContentLength_{0};
mutable SafeStringMap<std::string> parameters_; mutable SafeStringMap<std::string> parameters_;
mutable std::shared_ptr<Json::Value> jsonPtr_; mutable std::shared_ptr<Json::Value> jsonPtr_;
SessionPtr sessionPtr_; SessionPtr sessionPtr_;
@ -620,6 +697,12 @@ class HttpRequestImpl : public HttpRequest
bool passThrough_{false}; bool passThrough_{false};
std::vector<std::string> routingParams_; std::vector<std::string> routingParams_;
ReqStreamStatus streamStatus_{ReqStreamStatus::None};
std::function<void()> streamFinishCb_;
RequestStreamReaderPtr streamReaderPtr_;
std::exception_ptr streamExceptionPtr_;
bool startProcessing_{false};
protected: protected:
std::string content_; std::string content_;
trantor::EventLoop *loop_; trantor::EventLoop *loop_;

View File

@ -36,19 +36,6 @@ HttpRequestParser::HttpRequestParser(const trantor::TcpConnectionPtr &connPtr)
{ {
} }
void HttpRequestParser::shutdownConnection(HttpStatusCode code)
{
auto connPtr = conn_.lock();
if (connPtr)
{
connPtr->send(utils::formattedString(
"HTTP/1.1 %d %s\r\nConnection: close\r\n\r\n",
code,
statusCodeToString(code).data()));
connPtr->shutdown();
}
}
bool HttpRequestParser::processRequestLine(const char *begin, const char *end) bool HttpRequestParser::processRequestLine(const char *begin, const char *end)
{ {
bool succeed = false; bool succeed = false;
@ -130,7 +117,7 @@ HttpRequestImplPtr HttpRequestParser::makeRequestForPool(HttpRequestImpl *ptr)
void HttpRequestParser::reset() void HttpRequestParser::reset()
{ {
assert(loop_->isInLoopThread()); assert(loop_->isInLoopThread());
currentContentLength_ = 0; remainContentLength_ = 0;
status_ = HttpRequestParseStatus::kExpectMethod; status_ = HttpRequestParseStatus::kExpectMethod;
if (requestsPool_.empty()) if (requestsPool_.empty())
{ {
@ -146,9 +133,12 @@ void HttpRequestParser::reset()
} }
/** /**
* @return return -1 if encounters any error in request * @return return -HttpStatusCode if encounters any http errors in request
* @return return -1 if encounters any other errors in request
* @return return 0 if request is not ready * @return return 0 if request is not ready
* @return return 1 if request is ready * @return return 1 if request is ready
* @return return 2 if request is ready and entering stream mode
* @return return 3 if request header is ready and entering stream mode
*/ */
int HttpRequestParser::parseRequest(MsgBuffer *buf) int HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
@ -166,18 +156,14 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
if (buf->readableBytes() > METHOD_MAX_LEN) if (buf->readableBytes() > METHOD_MAX_LEN)
{ {
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
return 0; return 0;
} }
// try read method // try read method
if (!request_->setMethod(buf->peek(), space)) if (!request_->setMethod(buf->peek(), space))
{ {
buf->retrieveAll(); return -k405MethodNotAllowed;
shutdownConnection(k405MethodNotAllowed);
return -1;
} }
status_ = HttpRequestParseStatus::kExpectRequestLine; status_ = HttpRequestParseStatus::kExpectRequestLine;
buf->retrieveUntil(space + 1); buf->retrieveUntil(space + 1);
@ -193,18 +179,14 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
/// The limit for request line is 64K bytes. response /// The limit for request line is 64K bytes. response
/// k414RequestURITooLarge /// k414RequestURITooLarge
/// TODO: Make this configurable? /// TODO: Make this configurable?
buf->retrieveAll(); return -k414RequestURITooLarge;
shutdownConnection(k414RequestURITooLarge);
return -1;
} }
return 0; return 0;
} }
if (!processRequestLine(buf->peek(), crlf)) if (!processRequestLine(buf->peek(), crlf))
{ {
// error // error
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
buf->retrieveUntil(crlf + CRLF_LEN); buf->retrieveUntil(crlf + CRLF_LEN);
status_ = HttpRequestParseStatus::kExpectHeaders; status_ = HttpRequestParseStatus::kExpectHeaders;
@ -219,9 +201,7 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
/// The limit for every request header is 64K bytes; /// The limit for every request header is 64K bytes;
/// TODO: Make this configurable? /// TODO: Make this configurable?
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
return 0; return 0;
} }
@ -246,21 +226,18 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
try try
{ {
currentContentLength_ = remainContentLength_ =
static_cast<size_t>(std::stoull(len)); static_cast<size_t>(std::stoull(len));
} }
catch (...) catch (...)
{ {
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
if (currentContentLength_ == 0) request_->contentLengthHeaderValue_ = remainContentLength_;
if (remainContentLength_ == 0)
{ {
// content-length = 0, request is over. // content-length = 0, request is over.
status_ = HttpRequestParseStatus::kGotAll; status_ = HttpRequestParseStatus::kGotAll;
++requestsCounter_;
return 1;
} }
else else
{ {
@ -276,8 +253,6 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
// no content-length and no transfer-encoding, // no content-length and no transfer-encoding,
// request is over. // request is over.
status_ = HttpRequestParseStatus::kGotAll; status_ = HttpRequestParseStatus::kGotAll;
++requestsCounter_;
return 1;
} }
else if (encode == "chunked") else if (encode == "chunked")
{ {
@ -285,43 +260,37 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
} }
else else
{ {
buf->retrieveAll(); return -k501NotImplemented;
shutdownConnection(k501NotImplemented);
return -1;
} }
} }
// Check max body size
if (remainContentLength_ >
HttpAppFrameworkImpl::instance().getClientMaxBodySize())
{
return -k413RequestEntityTooLarge;
}
// Check expect:100-continue
auto &expect = request_->expect(); auto &expect = request_->expect();
if (expect == "100-continue" && if (expect == "100-continue" &&
request_->getVersion() >= Version::kHttp11) request_->getVersion() >= Version::kHttp11)
{ {
if (currentContentLength_ == 0) if (remainContentLength_ == 0)
{ {
// error // error
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
}
// rfc2616-8.2.3
auto connPtr = conn_.lock();
if (!connPtr)
{
return -1;
}
auto resp = HttpResponse::newHttpResponse();
if (currentContentLength_ >
HttpAppFrameworkImpl::instance().getClientMaxBodySize())
{
resp->setStatusCode(k413RequestEntityTooLarge);
auto httpString =
static_cast<HttpResponseImpl *>(resp.get())
->renderToBuffer();
reset();
connPtr->send(std::move(*httpString));
// TODO: missing logic here
} }
else else
{ {
// rfc2616-8.2.3
// TODO: consider adding an AOP for expect header
auto connPtr = conn_.lock(); // ugly
if (!connPtr)
{
return -1;
}
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k100Continue); resp->setStatusCode(k100Continue);
auto httpString = auto httpString =
static_cast<HttpResponseImpl *>(resp.get()) static_cast<HttpResponseImpl *>(resp.get())
@ -332,35 +301,50 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
else if (!expect.empty()) else if (!expect.empty())
{ {
LOG_WARN << "417ExpectationFailed for \"" << expect << "\""; LOG_WARN << "417ExpectationFailed for \"" << expect << "\"";
buf->retrieveAll(); return -k417ExpectationFailed;
shutdownConnection(k417ExpectationFailed);
return -1;
} }
else if (currentContentLength_ >
HttpAppFrameworkImpl::instance() assert(status_ == HttpRequestParseStatus::kGotAll ||
.getClientMaxBodySize()) status_ == HttpRequestParseStatus::kExpectBody ||
status_ == HttpRequestParseStatus::kExpectChunkLen);
if (app().isRequestStreamEnabled())
{ {
buf->retrieveAll(); request_->streamStart();
shutdownConnection(k413RequestEntityTooLarge); if (status_ == HttpRequestParseStatus::kGotAll)
return -1; {
++requestsCounter_;
return 2;
}
else
{
return 3;
}
}
// Reserve space for full body in non-stream mode.
// For stream mode requests that match a non-stream handler,
// we will reserve full body before waitForStreamFinish().
if (remainContentLength_)
{
request_->reserveBodySize(remainContentLength_);
} }
request_->reserveBodySize(currentContentLength_);
continue; continue;
} }
case HttpRequestParseStatus::kExpectBody: case HttpRequestParseStatus::kExpectBody:
{ {
size_t bytesToConsume = size_t bytesToConsume =
currentContentLength_ <= buf->readableBytes() remainContentLength_ <= buf->readableBytes()
? currentContentLength_ ? remainContentLength_
: buf->readableBytes(); : buf->readableBytes();
if (bytesToConsume) if (bytesToConsume)
{ {
request_->appendToBody(buf->peek(), bytesToConsume); request_->appendToBody(buf->peek(), bytesToConsume);
buf->retrieve(bytesToConsume); buf->retrieve(bytesToConsume);
currentContentLength_ -= bytesToConsume; remainContentLength_ -= bytesToConsume;
} }
if (currentContentLength_ == 0) if (remainContentLength_ == 0)
{ {
status_ = HttpRequestParseStatus::kGotAll; status_ = HttpRequestParseStatus::kGotAll;
++requestsCounter_; ++requestsCounter_;
@ -376,9 +360,7 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
if (buf->readableBytes() > TRUNK_LEN_MAX_LEN + CRLF_LEN) if (buf->readableBytes() > TRUNK_LEN_MAX_LEN + CRLF_LEN)
{ {
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
return 0; return 0;
} }
@ -388,12 +370,10 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
currentChunkLength_ = strtol(len.c_str(), &end, 16); currentChunkLength_ = strtol(len.c_str(), &end, 16);
if (currentChunkLength_ != 0) if (currentChunkLength_ != 0)
{ {
if (currentChunkLength_ + currentContentLength_ > if (currentChunkLength_ + remainContentLength_ >
HttpAppFrameworkImpl::instance().getClientMaxBodySize()) HttpAppFrameworkImpl::instance().getClientMaxBodySize())
{ {
buf->retrieveAll(); return -k413RequestEntityTooLarge;
shutdownConnection(k413RequestEntityTooLarge);
return -1;
} }
status_ = HttpRequestParseStatus::kExpectChunkBody; status_ = HttpRequestParseStatus::kExpectChunkBody;
} }
@ -414,13 +394,11 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
*(buf->peek() + currentChunkLength_ + 1) != '\n') *(buf->peek() + currentChunkLength_ + 1) != '\n')
{ {
// error! // error!
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
request_->appendToBody(buf->peek(), currentChunkLength_); request_->appendToBody(buf->peek(), currentChunkLength_);
buf->retrieve(currentChunkLength_ + CRLF_LEN); buf->retrieve(currentChunkLength_ + CRLF_LEN);
currentContentLength_ += currentChunkLength_; remainContentLength_ += currentChunkLength_;
currentChunkLength_ = 0; currentChunkLength_ = 0;
status_ = HttpRequestParseStatus::kExpectChunkLen; status_ = HttpRequestParseStatus::kExpectChunkLen;
continue; continue;
@ -435,25 +413,44 @@ int HttpRequestParser::parseRequest(MsgBuffer *buf)
if (*(buf->peek()) != '\r' || *(buf->peek() + 1) != '\n') if (*(buf->peek()) != '\r' || *(buf->peek() + 1) != '\n')
{ {
// error! // error!
buf->retrieveAll(); return -k400BadRequest;
shutdownConnection(k400BadRequest);
return -1;
} }
buf->retrieve(CRLF_LEN); buf->retrieve(CRLF_LEN);
if (!request_->isStreamMode())
{
// Previously we only have non-stream mode, drogon handled
// chunked encoding internally, and give user a regular
// request as if it has a content-length header.
//
// We have to keep compatibility for non-stream mode.
//
// But I don't think it's a good implementation. We should
// instead add an api to access real content-length of
// requests.
// Now HttpRequest::realContentLength() is added, and user
// should no longer parse content-length header by
// themselves.
//
// NOTE: request forward behavior may be infected in stream
// mode, we should check it out.
request_->addHeader("content-length",
std::to_string(
request_->realContentLength()));
request_->removeHeaderBy("transfer-encoding");
}
status_ = HttpRequestParseStatus::kGotAll; status_ = HttpRequestParseStatus::kGotAll;
request_->addHeader("content-length",
std::to_string(request_->bodyLength()));
request_->removeHeaderBy("transfer-encoding");
++requestsCounter_; ++requestsCounter_;
return 1; return 1;
} }
case HttpRequestParseStatus::kGotAll: case HttpRequestParseStatus::kGotAll:
{ {
++requestsCounter_;
return 1; return 1;
} }
} }
} }
return -1; return -1; // won't reach here, just to make compiler happy
} }
void HttpRequestParser::pushRequestToPipelining(const HttpRequestPtr &req, void HttpRequestParser::pushRequestToPipelining(const HttpRequestPtr &req,

View File

@ -43,7 +43,6 @@ class HttpRequestParser : public trantor::NonCopyable,
explicit HttpRequestParser(const trantor::TcpConnectionPtr &connPtr); explicit HttpRequestParser(const trantor::TcpConnectionPtr &connPtr);
// return false if any error
int parseRequest(trantor::MsgBuffer *buf); int parseRequest(trantor::MsgBuffer *buf);
bool gotAll() const bool gotAll() const
@ -138,7 +137,6 @@ class HttpRequestParser : public trantor::NonCopyable,
private: private:
HttpRequestImplPtr makeRequestForPool(HttpRequestImpl *p); HttpRequestImplPtr makeRequestForPool(HttpRequestImpl *p);
void shutdownConnection(HttpStatusCode code);
bool processRequestLine(const char *begin, const char *end); bool processRequestLine(const char *begin, const char *end);
HttpRequestParseStatus status_; HttpRequestParseStatus status_;
trantor::EventLoop *loop_; trantor::EventLoop *loop_;
@ -156,7 +154,7 @@ class HttpRequestParser : public trantor::NonCopyable,
std::unique_ptr<std::vector<HttpRequestImplPtr>> requestBuffer_; std::unique_ptr<std::vector<HttpRequestImplPtr>> requestBuffer_;
std::vector<HttpRequestImplPtr> requestsPool_; std::vector<HttpRequestImplPtr> requestsPool_;
size_t currentChunkLength_{0}; size_t currentChunkLength_{0};
size_t currentContentLength_{0}; size_t remainContentLength_{0};
}; };
} // namespace drogon } // namespace drogon

View File

@ -132,6 +132,13 @@ void HttpServer::onConnection(const TcpConnectionPtr &conn)
{ {
requestParser->webSocketConn()->onClose(); requestParser->webSocketConn()->onClose();
} }
else if (requestParser->requestImpl()->isStreamMode())
{
requestParser->requestImpl()->streamError(
std::make_exception_ptr(
StreamError(StreamErrorCode::kConnectionBroken,
"Connection closed")));
}
conn->clearContext(); conn->clearContext();
} }
} }
@ -162,28 +169,75 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, MsgBuffer *buf)
buf->retrieveAll(); buf->retrieveAll();
return; return;
} }
auto &req = requestParser->requestImpl();
// if stream mode enabled, parseRequest() may return >0 multiple times
// for the same request
int parseRes = requestParser->parseRequest(buf); int parseRes = requestParser->parseRequest(buf);
if (parseRes < 0) if (parseRes < 0)
{ {
if (req->isStreamMode() && req->isProcessingStarted())
{
// After entering stream mode, if request matches a non-stream
// handler, stream error would be intercepted by the
// `waitForStreamFinish()` call.
// If request matches a stream handler, stream error should be
// captured by user provided StreamReader, and response should
// also be sent by user.
req->streamError(std::make_exception_ptr(
StreamError(StreamErrorCode::kBadRequest, "Bad request")));
}
else if (parseRes != -1)
{
// In non-stream mode, request won't be process until it's fully
// parsed. To keep the old behavior, we send response directly
// through conn. (This response won't go through pre-sending
// aop, maybe we should change this behavior).
auto code = static_cast<HttpStatusCode>(-parseRes);
conn->send(utils::formattedString(
"HTTP/1.1 %d %s\r\nConnection: close\r\n\r\n",
code,
statusCodeToString(code).data()));
}
buf->retrieveAll();
// NOTE: should we call conn->forceClose() instead?
// Calling shutdown() handles socket more elegantly.
conn->shutdown();
// We have to call clearContext() here in order to ignore following
// illegal data from client
conn->clearContext();
requestParser->reset(); requestParser->reset();
conn->forceClose();
return; return;
} }
if (parseRes == 0) if (parseRes == 0)
{ {
break; break;
} }
auto &req = requestParser->requestImpl(); if (parseRes >= 2 || parseRes == 1 && !req->isStreamMode())
req->setPeerAddr(conn->peerAddr()); {
req->setLocalAddr(conn->localAddr()); req->setPeerAddr(conn->peerAddr());
req->setCreationDate(trantor::Date::date()); req->setLocalAddr(conn->localAddr());
req->setSecure(conn->isSSLConnection()); req->setCreationDate(trantor::Date::date());
req->setPeerCertificate(conn->peerCertificate()); req->setSecure(conn->isSSLConnection());
requests.push_back(req); req->setPeerCertificate(conn->peerCertificate());
requestParser->reset(); // TODO: maybe call onRequests() directly in stream mode
requests.push_back(req);
}
if (parseRes == 1 || parseRes == 2)
{
assert(requestParser->gotAll());
if (req->isStreamMode())
{
req->streamFinish();
}
requestParser->reset();
}
}
if (!requests.empty())
{
onRequests(conn, requests, requestParser);
requests.clear();
} }
onRequests(conn, requests, requestParser);
requests.clear();
} }
struct CallbackParamPack struct CallbackParamPack
@ -214,14 +268,14 @@ void HttpServer::onRequests(
const std::vector<HttpRequestImplPtr> &requests, const std::vector<HttpRequestImplPtr> &requests,
const std::shared_ptr<HttpRequestParser> &requestParser) const std::shared_ptr<HttpRequestParser> &requestParser)
{ {
if (requests.empty()) assert(!requests.empty());
return;
// will only be checked for the first request // will only be checked for the first request
if (requestParser->firstReq() && requests.size() == 1 && if (requestParser->firstReq() && requests.size() == 1 &&
isWebSocket(requests[0])) isWebSocket(requests[0]))
{ {
auto &req = requests[0]; auto &req = requests[0];
req->startProcessing();
if (passSyncAdvices(req, if (passSyncAdvices(req,
requestParser, requestParser,
false /* Not pipelined */, false /* Not pipelined */,
@ -287,6 +341,7 @@ void HttpServer::onRequests(
for (auto &req : requests) for (auto &req : requests)
{ {
req->startProcessing();
bool isHeadMethod = (req->method() == Head); bool isHeadMethod = (req->method() == Head);
if (isHeadMethod) if (isHeadMethod)
{ {
@ -421,6 +476,47 @@ void HttpServer::httpRequestRouting(
template <typename Pack> template <typename Pack>
void HttpServer::requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack) void HttpServer::requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack)
{ {
// Handle stream mode for non-stream handlers
if (req->streamStatus() >= ReqStreamStatus::Open &&
!pack.binderPtr->isStreamHandler())
{
LOG_TRACE << "Wait for request stream finish";
if (req->streamStatus() == ReqStreamStatus::Finish)
{
req->quitStreamMode();
}
else
{
auto contentLength = req->getContentLengthHeaderValue();
if (contentLength.has_value())
{
req->reserveBodySize(contentLength.value());
}
req->waitForStreamFinish([weakReq = std::weak_ptr(req),
pack =
std::forward<Pack>(pack)]() mutable {
auto req = weakReq.lock();
if (!req)
return;
if (req->streamStatus() == ReqStreamStatus::Finish)
{
req->quitStreamMode();
// call requestPostRouting again
requestPostRouting(req, std::forward<Pack>(pack));
return;
}
else
{
req->quitStreamMode();
LOG_DEBUG << "Stop processing request due to stream error";
pack.callback(
app().getCustomErrorHandler()(k400BadRequest, req));
}
});
return;
}
}
// post-routing aop // post-routing aop
auto &aop = AopAdvice::instance(); auto &aop = AopAdvice::instance();
aop.passPostRoutingObservers(req); aop.passPostRoutingObservers(req);

View File

@ -0,0 +1,356 @@
/**
*
* @file MultipartStreamParser.h
* @author Nitromelon
*
* Copyright 2024, Nitromelon. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "MultipartStreamParser.h"
#include <cassert>
using namespace drogon;
static bool startsWith(const std::string_view &a, const std::string_view &b)
{
if (a.size() < b.size())
{
return false;
}
for (size_t i = 0; i < b.size(); i++)
{
if (a[i] != b[i])
{
return false;
}
}
return true;
}
static bool startsWithIgnoreCase(const std::string_view &a,
const std::string_view &b)
{
if (a.size() < b.size())
{
return false;
}
for (size_t i = 0; i < b.size(); i++)
{
if (::tolower(a[i]) != ::tolower(b[i]))
{
return false;
}
}
return true;
}
MultipartStreamParser::MultipartStreamParser(const std::string &contentType)
{
static const std::string_view multipart = "multipart/form-data";
static const std::string_view boundaryEq = "boundary=";
if (!startsWithIgnoreCase(contentType, multipart))
{
isValid_ = false;
return;
}
auto pos = contentType.find(boundaryEq, multipart.size());
if (pos == std::string::npos)
{
isValid_ = false;
return;
}
pos += boundaryEq.size();
size_t pos2;
if (contentType[pos] == '"')
{
++pos;
pos2 = contentType.find('"', pos);
}
else
{
pos2 = contentType.find(';', pos);
}
if (pos2 == std::string::npos)
pos2 = contentType.size();
boundary_ = contentType.substr(pos, pos2 - pos);
dashBoundaryCrlf_ = dash_ + boundary_ + crlf_;
crlfDashBoundary_ = crlf_ + dash_ + boundary_;
}
// TODO: same function in HttpRequestParser.cc
static std::pair<std::string_view, std::string_view> parseLine(
const char *begin,
const char *end)
{
auto p = begin;
while (p != end)
{
if (*p == ':')
{
if (p + 1 != end && *(p + 1) == ' ')
{
return std::make_pair(std::string_view(begin, p - begin),
std::string_view(p + 2, end - p - 2));
}
else
{
return std::make_pair(std::string_view(begin, p - begin),
std::string_view(p + 1, end - p - 1));
}
}
++p;
}
return std::make_pair(std::string_view(), std::string_view());
}
void drogon::MultipartStreamParser::parse(
const char *data,
size_t length,
const drogon::RequestStreamReader::MultipartHeaderCallback &headerCb,
const drogon::RequestStreamReader::StreamDataCallback &dataCb)
{
buffer_.append(data, length);
while (buffer_.size() > 0)
{
switch (status_)
{
case Status::kExpectFirstBoundary:
{
if (buffer_.size() < dashBoundaryCrlf_.size())
{
return;
}
std::string_view v = buffer_.view();
auto pos = v.find(dashBoundaryCrlf_);
// ignore everything before the first boundary
if (pos == std::string::npos)
{
buffer_.eraseFront(buffer_.size() -
dashBoundaryCrlf_.size());
return;
}
// found
buffer_.eraseFront(pos + dashBoundaryCrlf_.size());
status_ = Status::kExpectNewEntry;
continue;
}
case Status::kExpectNewEntry:
{
currentHeader_.name.clear();
currentHeader_.filename.clear();
currentHeader_.contentType.clear();
status_ = Status::kExpectHeader;
continue;
}
case Status::kExpectHeader:
{
std::string_view v = buffer_.view();
auto pos = v.find(crlf_);
if (pos == std::string::npos)
{
// same magic number in HttpRequestParser::parseRequest()
if (buffer_.size() > 60 * 1024)
{
isValid_ = false;
}
return; // header incomplete, wait for more data
}
// empty line
if (pos == 0)
{
buffer_.eraseFront(crlf_.size());
status_ = Status::kExpectBody;
headerCb(currentHeader_);
continue;
}
// found header line
auto [keyView, valueView] = parseLine(v.data(), v.data() + pos);
if (keyView.empty() || valueView.empty())
{
// Bad header
isValid_ = false;
return;
}
if (startsWithIgnoreCase(keyView, "content-type"))
{
currentHeader_.contentType = valueView;
}
else if (startsWithIgnoreCase(keyView, "content-disposition"))
{
static const std::string_view nameKey = "name=";
static const std::string_view fileNameKey = "filename=";
// Extract name
auto namePos = valueView.find(nameKey);
if (namePos == std::string::npos)
{
// name absent
isValid_ = false;
return;
}
namePos += nameKey.size();
size_t nameEnd;
if (valueView[namePos] == '"')
{
++namePos;
nameEnd = valueView.find('"', namePos);
}
else
{
nameEnd = valueView.find(';', namePos);
}
if (nameEnd == std::string::npos)
{
// name end not found
isValid_ = false;
return;
}
currentHeader_.name =
valueView.substr(namePos, nameEnd - namePos);
// Extract filename
auto fileNamePos = valueView.find(fileNameKey, nameEnd);
if (fileNamePos != std::string::npos)
{
fileNamePos += fileNameKey.size();
size_t fileNameEnd;
if (valueView[fileNamePos] == '"')
{
++fileNamePos;
fileNameEnd = valueView.find('"', fileNamePos);
}
else
{
fileNameEnd = valueView.find(';', fileNamePos);
}
currentHeader_.filename =
valueView.substr(fileNamePos,
fileNameEnd - fileNamePos);
}
}
// ignore other headers
buffer_.eraseFront(pos + crlf_.size());
continue;
}
case Status::kExpectBody:
{
if (buffer_.size() < crlfDashBoundary_.size())
{
return; // not enough data to check boundary
}
std::string_view v = buffer_.view();
auto pos = v.find(crlfDashBoundary_);
if (pos == std::string::npos)
{
// boundary not found, leave potential partial boundary
size_t len = v.size() - crlfDashBoundary_.size();
if (len > 0)
{
dataCb(v.data(), len);
buffer_.eraseFront(len);
}
return;
}
// found boundary
dataCb(v.data(), pos);
if (pos > 0)
{
dataCb(v.data() + pos, 0); // notify end of file
}
buffer_.eraseFront(pos + crlfDashBoundary_.size());
status_ = Status::kExpectEndOrNewEntry;
continue;
}
case Status::kExpectEndOrNewEntry:
{
std::string_view v = buffer_.view();
// Check new entry
if (v.size() < crlf_.size())
{
return;
}
if (startsWith(v, crlf_))
{
buffer_.eraseFront(crlf_.size());
status_ = Status::kExpectNewEntry;
continue;
}
// Check end
if (v.size() < dash_.size())
{
return;
}
if (startsWith(v, dash_))
{
isFinished_ = true;
buffer_.clear(); // ignore epilogue
return;
}
isValid_ = false;
return;
}
}
}
}
std::string_view MultipartStreamParser::Buffer::view() const
{
return {buffer_.data() + bufHead_, size()};
}
void MultipartStreamParser::Buffer::append(const char *data, size_t length)
{
size_t remainSize = size();
// Move existing data to the front
if (remainSize > 0 && bufHead_ > 0)
{
for (size_t i = 0; i < remainSize; i++)
{
buffer_[i] = buffer_[bufHead_ + i];
}
}
bufHead_ = 0;
bufTail_ = remainSize;
if (remainSize + length > buffer_.size())
{
buffer_.resize(remainSize + length);
}
for (size_t i = 0; i < length; ++i)
{
buffer_[bufTail_ + i] = data[i];
}
bufTail_ += length;
}
size_t MultipartStreamParser::Buffer::size() const
{
return bufTail_ - bufHead_;
}
void MultipartStreamParser::Buffer::eraseFront(size_t length)
{
assert(length <= size());
bufHead_ += length;
}
void MultipartStreamParser::Buffer::clear()
{
buffer_.clear();
bufHead_ = 0;
bufTail_ = 0;
}

View File

@ -0,0 +1,77 @@
/**
*
* @file MultipartStreamParser.h
* @author Nitromelon
*
* Copyright 2024, Nitromelon. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/exports.h>
#include <drogon/RequestStream.h>
#include <string>
namespace drogon
{
class DROGON_EXPORT MultipartStreamParser
{
public:
MultipartStreamParser(const std::string &contentType);
void parse(const char *data,
size_t length,
const RequestStreamReader::MultipartHeaderCallback &headerCb,
const RequestStreamReader::StreamDataCallback &dataCb);
bool isFinished() const
{
return isFinished_;
}
bool isValid() const
{
return isValid_;
}
private:
const std::string dash_ = "--";
const std::string crlf_ = "\r\n";
std::string boundary_;
std::string dashBoundaryCrlf_;
std::string crlfDashBoundary_;
struct Buffer
{
public:
std::string_view view() const;
void append(const char *data, size_t length);
size_t size() const;
void eraseFront(size_t length);
void clear();
private:
std::string buffer_;
size_t bufHead_{0};
size_t bufTail_{0};
} buffer_;
enum class Status
{
kExpectFirstBoundary = 0,
kExpectNewEntry = 1,
kExpectHeader = 2,
kExpectBody = 3,
kExpectEndOrNewEntry = 4,
} status_{Status::kExpectFirstBoundary};
MultipartHeader currentHeader_;
bool isValid_{true};
bool isFinished_{false};
};
} // namespace drogon

225
lib/src/RequestStream.cc Normal file
View File

@ -0,0 +1,225 @@
/**
*
* @file RequestStream.cc
* @author Nitromelon
*
* Copyright 2024, Nitromelon. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "MultipartStreamParser.h"
#include "HttpRequestImpl.h"
#include <drogon/RequestStream.h>
#include <variant>
namespace drogon
{
class RequestStreamImpl : public RequestStream
{
public:
RequestStreamImpl(const HttpRequestImplPtr &req) : weakReq_(req)
{
}
~RequestStreamImpl() override
{
if (isSet_.exchange(true))
{
return;
}
// Drop all data if no reader is set
if (auto req = weakReq_.lock())
{
setHandlerInLoop(req, RequestStreamReader::newNullReader());
}
}
void setStreamReader(RequestStreamReaderPtr reader) override
{
if (isSet_.exchange(true))
{
return;
}
if (auto req = weakReq_.lock())
{
setHandlerInLoop(req, std::move(reader));
}
}
void setHandlerInLoop(const HttpRequestImplPtr &req,
RequestStreamReaderPtr reader)
{
if (!req->isStreamMode())
{
return;
}
auto loop = req->getLoop();
if (loop->isInLoopThread())
{
req->setStreamReader(std::move(reader));
}
else
{
loop->queueInLoop([req, reader = std::move(reader)]() mutable {
req->setStreamReader(std::move(reader));
});
}
}
private:
std::weak_ptr<HttpRequestImpl> weakReq_;
std::atomic_bool isSet_{false};
};
namespace internal
{
RequestStreamPtr createRequestStream(const HttpRequestPtr &req)
{
auto reqImpl = std::static_pointer_cast<HttpRequestImpl>(req);
if (!reqImpl->isStreamMode())
{
return nullptr;
}
return std::make_shared<RequestStreamImpl>(
std::static_pointer_cast<HttpRequestImpl>(req));
}
} // namespace internal
/**
* A default implementation for convenience
*/
class DefaultStreamReader : public RequestStreamReader
{
public:
DefaultStreamReader(StreamDataCallback dataCb,
StreamFinishCallback finishCb)
: dataCb_(std::move(dataCb)), finishCb_(std::move(finishCb))
{
}
void onStreamData(const char *data, size_t length) override
{
dataCb_(data, length);
}
void onStreamFinish(std::exception_ptr ex) override
{
finishCb_(std::move(ex));
}
private:
StreamDataCallback dataCb_;
StreamFinishCallback finishCb_;
};
/**
* Drops all data
*/
class NullStreamReader : public RequestStreamReader
{
public:
void onStreamData(const char *, size_t length) override
{
}
void onStreamFinish(std::exception_ptr) override
{
}
};
/**
* Parse multipart data and return actual content
*/
class MultipartStreamReader : public RequestStreamReader
{
public:
MultipartStreamReader(const std::string &contentType,
MultipartHeaderCallback headerCb,
StreamDataCallback dataCb,
StreamFinishCallback finishCb)
: parser_(contentType),
headerCb_(std::move(headerCb)),
dataCb_(std::move(dataCb)),
finishCb_(std::move(finishCb))
{
}
void onStreamData(const char *data, size_t length) override
{
if (!parser_.isValid() || parser_.isFinished())
{
return;
}
parser_.parse(data, length, headerCb_, dataCb_);
if (!parser_.isValid())
{
// TODO: should we mix stream error and user error?
finishCb_(std::make_exception_ptr(
std::runtime_error("invalid multipart data")));
}
else if (parser_.isFinished())
{
finishCb_({});
}
}
void onStreamFinish(std::exception_ptr ex) override
{
if (!parser_.isValid() || parser_.isFinished())
{
return;
}
if (!ex)
{
finishCb_(std::make_exception_ptr(
std::runtime_error("incomplete multipart data")));
}
else
{
finishCb_(std::move(ex));
}
}
private:
MultipartStreamParser parser_;
MultipartHeaderCallback headerCb_;
StreamDataCallback dataCb_;
StreamFinishCallback finishCb_;
};
RequestStreamReaderPtr RequestStreamReader::newReader(
StreamDataCallback dataCb,
StreamFinishCallback finishCb)
{
return std::make_shared<DefaultStreamReader>(std::move(dataCb),
std::move(finishCb));
}
RequestStreamReaderPtr RequestStreamReader::newNullReader()
{
return std::make_shared<NullStreamReader>();
}
RequestStreamReaderPtr RequestStreamReader::newMultipartReader(
const HttpRequestPtr &req,
MultipartHeaderCallback headerCb,
StreamDataCallback dataCb,
StreamFinishCallback finishCb)
{
return std::make_shared<MultipartStreamReader>(req->getHeader(
"content-type"),
std::move(headerCb),
std::move(dataCb),
std::move(finishCb));
}
} // namespace drogon

View File

@ -52,7 +52,8 @@ if (BUILD_CTL)
integration_test/client/main.cc integration_test/client/main.cc
integration_test/client/WebSocketTest.cc integration_test/client/WebSocketTest.cc
integration_test/client/MultipleWsTest.cc integration_test/client/MultipleWsTest.cc
integration_test/client/HttpPipeliningTest.cc) integration_test/client/HttpPipeliningTest.cc
integration_test/client/RequestStreamTest.cc)
add_executable(integration_test_client ${INTEGRATION_TEST_CLIENT_SOURCES}) add_executable(integration_test_client ${INTEGRATION_TEST_CLIENT_SOURCES})
set(INTEGRATION_TEST_SERVER_SOURCES set(INTEGRATION_TEST_SERVER_SOURCES
@ -75,6 +76,7 @@ if (BUILD_CTL)
integration_test/server/RangeTestController.cc integration_test/server/RangeTestController.cc
integration_test/server/BeginAdviceTest.cc integration_test/server/BeginAdviceTest.cc
integration_test/server/MiddlewareTest.cc integration_test/server/MiddlewareTest.cc
integration_test/server/RequestStreamTestCtrl.cc
integration_test/server/main.cc) integration_test/server/main.cc)
if(DROGON_CXX_STANDARD GREATER_EQUAL 20 AND HAS_COROUTINE) if(DROGON_CXX_STANDARD GREATER_EQUAL 20 AND HAS_COROUTINE)

View File

@ -0,0 +1,143 @@
#include <drogon/HttpClient.h>
#include <drogon/drogon_test.h>
#include <trantor/net/TcpClient.h>
#include <chrono>
#include <string>
#include <iostream>
#include <fstream>
using namespace drogon;
template <typename T>
void checkStreamRequest(T &&TEST_CTX,
trantor::EventLoop *loop,
const trantor::InetAddress &addr,
const std::vector<std::string_view> &dataToSend,
std::string_view expectedResp)
{
auto tcpClient = std::make_shared<trantor::TcpClient>(loop, addr, "test");
std::promise<void> promise;
auto respString = std::make_shared<std::string>();
tcpClient->setMessageCallback(
[respString](const trantor::TcpConnectionPtr &conn,
trantor::MsgBuffer *buf) {
respString->append(buf->read(buf->readableBytes()));
});
tcpClient->setConnectionCallback(
[TEST_CTX, &promise, respString, dataToSend, expectedResp](
const trantor::TcpConnectionPtr &conn) {
if (conn->disconnected())
{
LOG_INFO << "Disconnected from server";
CHECK(respString->substr(0, expectedResp.size()) ==
expectedResp);
promise.set_value();
return;
}
LOG_INFO << "Connected to server";
CHECK(conn->connected());
for (auto &data : dataToSend)
{
conn->send(data.data(), data.size());
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
conn->shutdown();
});
tcpClient->connect();
promise.get_future().wait();
}
DROGON_TEST(RequestStreamTest)
{
const std::string ip = "127.0.0.1";
const uint16_t port = 8848;
auto client = HttpClient::newHttpClient(ip, port);
HttpRequestPtr req;
bool enabled = false;
req = HttpRequest::newHttpRequest();
req->setPath("/stream_status");
{
auto [res, resp] = client->sendRequest(req);
REQUIRE(res == ReqResult::Ok);
REQUIRE(resp->statusCode() == k200OK);
if (resp->body() == "enabled")
{
enabled = true;
}
else
{
LOG_INFO << "Server does not enable request stream.";
}
}
req = HttpRequest::newHttpRequest();
req->setPath("/stream_chunk");
req->setMethod(Post);
req->setBody("1234567890");
client->sendRequest(req,
[TEST_CTX, enabled](ReqResult r,
const HttpResponsePtr &resp) {
REQUIRE(r == ReqResult::Ok);
if (enabled)
{
CHECK(resp->statusCode() == k200OK);
CHECK(resp->body() == "1234567890");
}
else
{
CHECK(resp->statusCode() == k400BadRequest);
CHECK(resp->body() == "no stream");
}
});
if (!enabled)
{
return;
}
LOG_INFO << "Test request stream";
std::string filePath = "./中文.txt";
std::ifstream file(filePath);
std::stringstream content;
REQUIRE(file.is_open());
content << file.rdbuf();
req = HttpRequest::newFileUploadRequest({UploadFile{filePath}});
req->setPath("/stream_upload_echo");
req->setMethod(Post);
client->sendRequest(req,
[TEST_CTX,
content = content.str()](ReqResult r,
const HttpResponsePtr &resp) {
CHECK(r == ReqResult::Ok);
CHECK(resp->statusCode() == k200OK);
CHECK(resp->body() == content);
});
checkStreamRequest(TEST_CTX,
client->getLoop(),
trantor::InetAddress{ip, port},
// Good request
{"POST /stream_chunk HTTP/1.1\r\n"
"Transfer-Encoding: chunked\r\n\r\n",
"1\r\nz\r\n",
"2\r\nzz\r\n0\r\n\r\n"},
// Good response
"HTTP/1.1 200 OK\r\n");
checkStreamRequest(TEST_CTX,
client->getLoop(),
trantor::InetAddress{ip, port},
// Bad request
{"POST /stream_chunk HTTP/1.1\r\n"
"Transfer-Encoding: chunked\r\n\r\n",
"1\r\nz\r\n",
"1\r\nzz\r\n",
"0\r\n\r\n"},
// Bad response
"HTTP/1.1 400 Bad Request\r\n");
}

View File

@ -0,0 +1,150 @@
#include <fstream>
#include <drogon/HttpController.h>
#include <drogon/HttpRequest.h>
#include <drogon/RequestStream.h>
using namespace drogon;
class RequestStreamTestCtrl : public HttpController<RequestStreamTestCtrl>
{
public:
METHOD_LIST_BEGIN
ADD_METHOD_TO(RequestStreamTestCtrl::stream_status, "/stream_status", Get);
ADD_METHOD_TO(RequestStreamTestCtrl::stream_chunk, "/stream_chunk", Post);
ADD_METHOD_TO(RequestStreamTestCtrl::stream_upload_echo,
"/stream_upload_echo",
Post);
METHOD_LIST_END
void stream_status(
const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
auto resp = HttpResponse::newHttpResponse();
if (app().isRequestStreamEnabled())
{
resp->setBody("enabled");
}
else
{
resp->setBody("not enabled");
}
callback(resp);
}
void stream_chunk(
const HttpRequestPtr &,
RequestStreamPtr &&stream,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
if (!stream)
{
LOG_INFO << "stream mode is not enabled";
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k400BadRequest);
resp->setBody("no stream");
callback(resp);
return;
}
auto respBody = std::make_shared<std::string>();
auto reader = RequestStreamReader::newReader(
[respBody](const char *data, size_t length) {
respBody->append(data, length);
},
[respBody, callback = std::move(callback)](std::exception_ptr ex) {
auto resp = HttpResponse::newHttpResponse();
if (ex)
{
try
{
std::rethrow_exception(std::move(ex));
}
catch (const std::exception &e)
{
LOG_ERROR << "stream error: " << e.what();
}
resp->setStatusCode(k400BadRequest);
resp->setBody("stream error");
callback(resp);
}
else
{
resp->setBody(*respBody);
callback(resp);
}
});
stream->setStreamReader(std::move(reader));
}
void stream_upload_echo(
const HttpRequestPtr &req,
RequestStreamPtr &&stream,
std::function<void(const HttpResponsePtr &)> &&callback) const
{
assert(drogon::app().isRequestStreamEnabled() || !stream);
if (!stream)
{
LOG_INFO << "stream mode is not enabled";
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k400BadRequest);
resp->setBody("no stream");
callback(resp);
return;
}
if (req->contentType() != CT_MULTIPART_FORM_DATA)
{
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k400BadRequest);
resp->setBody("should upload multipart");
callback(resp);
return;
}
struct Context
{
std::string firstFileContent;
size_t currentFileIndex_{0};
};
auto ctx = std::make_shared<Context>();
auto reader = RequestStreamReader::newMultipartReader(
req,
[ctx](MultipartHeader &&header) { ctx->currentFileIndex_++; },
[ctx](const char *data, size_t length) {
if (ctx->currentFileIndex_ == 1)
{
ctx->firstFileContent.append(data, length);
}
},
[ctx, callback = std::move(callback)](std::exception_ptr ex) {
auto resp = HttpResponse::newHttpResponse();
if (ex)
{
try
{
std::rethrow_exception(std::move(ex));
}
catch (const StreamError &e)
{
LOG_ERROR << "stream error: " << e.what();
}
catch (const std::exception &e)
{
LOG_ERROR << "multipart error: " << e.what();
}
resp->setStatusCode(k400BadRequest);
resp->setBody("error\n");
callback(resp);
}
else
{
resp->setBody(ctx->firstFileContent);
callback(resp);
}
});
stream->setStreamReader(std::move(reader));
}
};

View File

@ -1,6 +1,7 @@
#include <drogon/MultiPart.h> #include <drogon/MultiPart.h>
#include <drogon/drogon_test.h> #include <drogon/drogon_test.h>
#include <drogon/HttpRequest.h> #include <drogon/HttpRequest.h>
#include "../../lib/src/MultipartStreamParser.h"
DROGON_TEST(MultiPartParser) DROGON_TEST(MultiPartParser)
{ {
@ -60,3 +61,72 @@ DROGON_TEST(MultiPartParser)
CHECK(parser4.getParameters().size() == 1); CHECK(parser4.getParameters().size() == 1);
CHECK(parser4.getParameters().at("some;key") == "Hello; World"); CHECK(parser4.getParameters().at("some;key") == "Hello; World");
} }
DROGON_TEST(MultiPartStreamParser)
{
static const std::string ct = "multipart/form-data; boundary=\"12345\"";
static const std::string_view data =
"--12345\r\n"
"Content-Disposition: form-data; name=\"key1\"; filename=\"file1\"\r\n"
"\r\n"
"Hello; World\r\n"
"--12345\r\n"
"Content-Disposition: form-data; name=\"key2\"\r\n"
"\r\n"
"value2\r\n"
"--12345--";
struct Entry
{
drogon::MultipartHeader header;
std::string value;
std::string fileContent;
};
auto check = [TEST_CTX](size_t step) {
drogon::MultipartStreamParser parser(ct);
auto entries = std::make_shared<std::vector<Entry>>();
auto headerCb = [TEST_CTX, entries](drogon::MultipartHeader hdr) {
entries->emplace_back(Entry{std::move(hdr)});
};
auto dataCb = [TEST_CTX, entries](const char *data, size_t length) {
MANDATE(!entries->empty());
if (length == 0)
{
// Field finished
return;
}
if (entries->back().header.filename.empty())
{
entries->back().value.append(data, length);
}
else
{
entries->back().fileContent.append(data, length);
}
};
size_t i = 0;
while (i < data.length() && parser.isValid())
{
size_t end = i + step < data.length() ? i + step : data.length();
parser.parse(data.data() + i, end - i, headerCb, dataCb);
CHECK(parser.isValid());
i = end;
}
MANDATE(i == data.length());
MANDATE(parser.isFinished());
MANDATE(entries->size() == 2);
CHECK(entries->at(0).header.name == "key1");
CHECK(entries->at(0).fileContent == "Hello; World");
CHECK(entries->at(1).header.name == "key2");
CHECK(entries->at(1).value == "value2");
};
check(1);
check(3);
check(7);
check(20);
}

11
test.sh
View File

@ -34,6 +34,12 @@ function do_integration_test()
sed -i -e "s/\"threads_num.*$/\"threads_num\": 0\,/" config.example.json sed -i -e "s/\"threads_num.*$/\"threads_num\": 0\,/" config.example.json
sed -i -e "s/\"use_brotli.*$/\"use_brotli\": true\,/" config.example.json sed -i -e "s/\"use_brotli.*$/\"use_brotli\": true\,/" config.example.json
if [ "$1" = "stream_mode" ]; then
sed -i -e "s/\"enable_request_stream.*$/\"enable_request_stream\": true\,/" config.example.json
else
sed -i -e "s/\"enable_request_stream.*$/\"enable_request_stream\": false\,/" config.example.json
fi
if [ ! -f "integration_test_client" ]; then if [ ! -f "integration_test_client" ]; then
echo "Build failed" echo "Build failed"
exit -1 exit -1
@ -48,11 +54,11 @@ function do_integration_test()
sleep 4 sleep 4
echo "Running the integration test" echo "Running the integration test $1"
./integration_test_client -s ./integration_test_client -s
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Integration test failed" echo "Integration test failed $1"
exit -1 exit -1
fi fi
@ -245,6 +251,7 @@ then
echo "Warning: No drogon_ctl, skip integration test and drogon_ctl test" echo "Warning: No drogon_ctl, skip integration test and drogon_ctl test"
else else
do_integration_test do_integration_test
do_integration_test stream_mode
do_drogon_ctl_test do_drogon_ctl_test
fi fi