Allow sync advice to be callable on websocket requests (#1733)

Co-authored-by: Nitromelon <hwc14@qq.com>
This commit is contained in:
Muhammad 2023-08-24 10:09:31 +03:00 committed by GitHub
parent 54b137d64f
commit f8f5283dff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -182,36 +182,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, MsgBuffer *buf)
req->setCreationDate(trantor::Date::date()); req->setCreationDate(trantor::Date::date());
req->setSecure(conn->isSSLConnection()); req->setSecure(conn->isSSLConnection());
req->setPeerCertificate(conn->peerCertificate()); req->setPeerCertificate(conn->peerCertificate());
if (requestParser->firstReq() && isWebSocket(req)) requests.push_back(req);
{
auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn);
wsConn->setPingMessage("", std::chrono::seconds{30});
newWebsocketCallback_(
req,
[conn, wsConn, requestParser, this, req](
const HttpResponsePtr &resp) mutable {
if (conn->connected())
{
for (auto &advice : preSendingAdvices_)
{
advice(req, resp);
}
if (resp->statusCode() == k101SwitchingProtocols)
{
requestParser->setWebsockConnection(wsConn);
}
auto httpString =
((HttpResponseImpl *)resp.get())->renderToBuffer();
conn->send(httpString);
COZ_PROGRESS
}
},
wsConn);
}
else
{
requests.push_back(req);
}
requestParser->reset(); requestParser->reset();
} }
onRequests(conn, requests, requestParser); onRequests(conn, requests, requestParser);
@ -248,6 +219,55 @@ void HttpServer::onRequests(
{ {
if (requests.empty()) if (requests.empty())
return; return;
// will only be checked for the first request
if (requestParser->firstReq() && requests.size() == 1 &&
isWebSocket(requests[0]))
{
auto &req = requests[0];
if (passSyncAdvices(req,
requestParser,
syncAdvices_,
false /* Not pipelined */,
false /* Not HEAD */))
{
auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn);
wsConn->setPingMessage("", std::chrono::seconds{30});
newWebsocketCallback_(
req,
[conn, wsConn, requestParser, this, req](
const HttpResponsePtr &resp) mutable {
if (conn->connected())
{
for (auto &advice : preSendingAdvices_)
{
advice(req, resp);
}
if (resp->statusCode() == k101SwitchingProtocols)
{
requestParser->setWebsockConnection(wsConn);
}
auto httpString =
((HttpResponseImpl *)resp.get())->renderToBuffer();
conn->send(httpString);
COZ_PROGRESS
}
},
wsConn);
return;
}
// flush response for not passing sync advices
if (conn->connected() && !requestParser->getResponseBuffer().empty())
{
sendResponses(conn,
requestParser->getResponseBuffer(),
requestParser->getBuffer());
requestParser->getResponseBuffer().clear();
}
return;
}
if (HttpAppFrameworkImpl::instance().keepaliveRequestsNumber() > 0 && if (HttpAppFrameworkImpl::instance().keepaliveRequestsNumber() > 0 &&
requestParser->numberOfRequestsParsed() >= requestParser->numberOfRequestsParsed() >=
HttpAppFrameworkImpl::instance().keepaliveRequestsNumber()) HttpAppFrameworkImpl::instance().keepaliveRequestsNumber())
@ -283,8 +303,7 @@ void HttpServer::onRequests(
requestParser->pushRequestToPipelining(req, isHeadMethod); requestParser->pushRequestToPipelining(req, isHeadMethod);
reqPipelined = true; reqPipelined = true;
} }
if (!syncAdvices_.empty() && if (!passSyncAdvices(
!passSyncAdvices(
req, requestParser, syncAdvices_, reqPipelined, isHeadMethod)) req, requestParser, syncAdvices_, reqPipelined, isHeadMethod))
{ {
continue; continue;
@ -678,10 +697,14 @@ void HttpServer::sendResponses(
static inline bool isWebSocket(const HttpRequestImplPtr &req) static inline bool isWebSocket(const HttpRequestImplPtr &req)
{ {
if (req->method() != Get)
return false;
auto &headers = req->headers(); auto &headers = req->headers();
if (headers.find("upgrade") == headers.end() || if (headers.find("upgrade") == headers.end() ||
headers.find("connection") == headers.end()) headers.find("connection") == headers.end())
return false; return false;
auto connectionField = req->getHeaderBy("connection"); auto connectionField = req->getHeaderBy("connection");
std::transform(connectionField.begin(), std::transform(connectionField.begin(),
connectionField.end(), connectionField.end(),