Merge pull request #117 from an-tao/dev

Modify the implementation of WebSocket
This commit is contained in:
An Tao 2019-04-08 17:10:33 +08:00 committed by GitHub
commit f13f330f3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 483 additions and 469 deletions

View File

@ -35,7 +35,7 @@ int main(int argc, char *argv[])
if (r == ReqResult::Ok)
{
std::cout << "ws connected!" << std::endl;
wsPtr->getConnection()->send("hello");
wsPtr->getConnection()->send("hello!");
}
else
{

View File

@ -34,7 +34,7 @@ class WebSocketClient
{
public:
/// Get the WebSocket connection that is typically used to send messages.
virtual const WebSocketConnectionPtr &getConnection() = 0;
virtual WebSocketConnectionPtr getConnection() = 0;
/// Set messages handler. When a message is recieved from the server, the @param callback is called.
virtual void setMessageHandler(const std::function<void(std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &)> &callback) = 0;

View File

@ -317,8 +317,6 @@ void HttpAppFrameworkImpl::run()
}
serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2));
serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3));
serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1));
serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1));
serverPtr->kickoffIdleConnections(_idleConnectionTimeout);
serverPtr->start();
@ -356,8 +354,6 @@ void HttpAppFrameworkImpl::run()
serverPtr->setIoLoopNum(_threadNum);
serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2));
serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3));
serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1));
serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1));
serverPtr->kickoffIdleConnections(_idleConnectionTimeout);
serverPtr->start();
@ -472,17 +468,7 @@ void HttpAppFrameworkImpl::createDbClients(const std::vector<trantor::EventLoop
}
}
#endif
void HttpAppFrameworkImpl::onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr)
{
auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr);
assert(wsConnImplPtr);
auto ctrl = wsConnImplPtr->controller();
if (ctrl)
{
ctrl->handleConnectionClosed(wsConnPtr);
wsConnImplPtr->setController(WebSocketControllerBasePtr());
}
}
void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn)
{
static std::mutex mtx;
@ -540,16 +526,6 @@ void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn)
}
}
void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type)
{
auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr);
assert(wsConnImplPtr);
auto ctrl = wsConnImplPtr->controller();
if (ctrl)
{
ctrl->handleNewMessage(wsConnPtr, std::move(message), type);
}
}
void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath)
{
@ -571,7 +547,7 @@ void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath)
}
void HttpAppFrameworkImpl::onNewWebsockRequest(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr)
const WebSocketConnectionImplPtr &wsConnPtr)
{
_websockCtrlsRouter.route(req, std::move(callback), wsConnPtr);
}

View File

@ -18,7 +18,7 @@
#include "HttpResponseImpl.h"
#include "HttpClientImpl.h"
#include "SharedLibManager.h"
#include "WebSockectConnectionImpl.h"
#include "WebSocketConnectionImpl.h"
#include "HttpControllersRouter.h"
#include "HttpSimpleControllersRouter.h"
#include "WebsocketControllersRouter.h"
@ -177,9 +177,7 @@ class HttpAppFrameworkImpl : public HttpAppFramework
void onAsyncRequest(const HttpRequestImplPtr &req, std::function<void(const HttpResponsePtr &)> &&callback);
void onNewWebsockRequest(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr);
void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type);
void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr);
const WebSocketConnectionImplPtr &wsConnPtr);
void onConnection(const TcpConnectionPtr &conn);
void addHttpPath(const std::string &path,
const internal::HttpBinderBasePtr &binder,

View File

@ -15,7 +15,7 @@
#pragma once
#include "HttpRequestImpl.h"
#include "WebSockectConnectionImpl.h"
#include "WebSocketConnectionImpl.h"
#include <trantor/utils/MsgBuffer.h>
#include <drogon/HttpResponse.h>
#include <deque>

View File

@ -51,7 +51,7 @@ static void defaultHttpAsyncCallback(const HttpRequestPtr &, std::function<void(
static void defaultWebSockAsyncCallback(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &resp)> &&callback,
const WebSocketConnectionPtr &wsConnPtr)
const WebSocketConnectionImplPtr &wsConnPtr)
{
auto resp = HttpResponse::newNotFoundResponse();
resp->setCloseConnection(true);
@ -104,7 +104,7 @@ void HttpServer::onConnection(const TcpConnectionPtr &conn)
{
if (requestParser->webSocketConn())
{
_disconnectWebsocketCallback(requestParser->webSocketConn());
requestParser->webSocketConn()->onClose();
}
#if (CXX_STD > 14)
conn->getMutableContext()->reset(); //reset(): since c++17
@ -125,33 +125,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
if (requestParser->webSocketConn())
{
//Websocket payload
while (buf->readableBytes() > 0)
{
std::string message;
WebSocketMessageType type;
auto success = parseWebsockMessage(buf, message, type);
if (success)
{
if (type == WebSocketMessageType::Ping)
{
//ping
requestParser->webSocketConn()->send(message, WebSocketMessageType::Pong);
}
else if (type == WebSocketMessageType::Close)
{
//close
conn->shutdown();
}
_webSocketMessageCallback(requestParser->webSocketConn(), std::move(message), type);
}
else
{
//Websock error!
conn->shutdown();
return;
}
}
return;
requestParser->webSocketConn()->onNewMessage(conn, buf);
}
else
{
@ -182,6 +156,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
if (resp->statusCode() == k101SwitchingProtocols)
{
requestParser->setWebsockConnection(wsConn);
}
auto httpString = std::dynamic_pointer_cast<HttpResponseImpl>(resp)->renderToString();
conn->send(httpString);

View File

@ -14,9 +14,10 @@
#pragma once
#include "WebSockectConnectionImpl.h"
#include "WebSocketConnectionImpl.h"
#include "HttpRequestImpl.h"
#include <drogon/config.h>
#include <drogon/WebSocketController.h>
#include <trantor/net/TcpServer.h>
#include <trantor/net/callbacks.h>
#include <trantor/utils/NonCopyable.h>
@ -35,12 +36,8 @@ class HttpServer : trantor::NonCopyable
typedef std::function<void(const HttpRequestImplPtr &, std::function<void(const HttpResponsePtr &)> &&)> HttpAsyncCallback;
typedef std::function<void(const HttpRequestImplPtr &,
std::function<void(const HttpResponsePtr &)> &&,
const WebSocketConnectionPtr &)>
const WebSocketConnectionImplPtr &)>
WebSocketNewAsyncCallback;
typedef std::function<void(const WebSocketConnectionPtr &)>
WebSocketDisconnetCallback;
typedef std::function<void(const WebSocketConnectionPtr &, std::string &&, const WebSocketMessageType &)>
WebSocketMessageCallback;
HttpServer(EventLoop *loop,
const InetAddress &listenAddr,
@ -58,14 +55,6 @@ class HttpServer : trantor::NonCopyable
{
_newWebsocketCallback = cb;
}
void setDisconnectWebsocketCallback(const WebSocketDisconnetCallback &cb)
{
_disconnectWebsocketCallback = cb;
}
void setWebsocketMessageCallback(const WebSocketMessageCallback &cb)
{
_webSocketMessageCallback = cb;
}
void setConnectionCallback(const ConnectionCallback &cb)
{
_connectionCallback = cb;
@ -104,8 +93,6 @@ class HttpServer : trantor::NonCopyable
trantor::TcpServer _server;
HttpAsyncCallback _httpAsyncCallback;
WebSocketNewAsyncCallback _newWebsocketCallback;
WebSocketDisconnetCallback _disconnectWebsocketCallback;
WebSocketMessageCallback _webSocketMessageCallback;
trantor::ConnectionCallback _connectionCallback;
};

View File

@ -375,110 +375,4 @@ const string_view &statusCodeToString(int code)
}
}
// Return false if any error
bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type)
{
assert(message.empty());
if (buffer->readableBytes() >= 2)
{
unsigned char opcode = (*buffer)[0] & 0x0f;
switch (opcode)
{
case 1:
type = WebSocketMessageType::Text;
break;
case 2:
type = WebSocketMessageType::Binary;
break;
case 8:
type = WebSocketMessageType::Close;
break;
case 9:
type = WebSocketMessageType::Ping;
break;
case 10:
type = WebSocketMessageType::Pong;
break;
default:
type = WebSocketMessageType::Unknown;
break;
}
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
// length=*((uint64_t *)(buffer->peek()+2));
// length=ntohll(length);
}
else
{
LOG_ERROR << "Websock parsing failed!";
return false;
}
}
if (isMasked != 0)
{
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
message.resize(length);
for (size_t i = 0; i < length; i++)
{
message[i] = (rawData[i] ^ masks[i % 4]);
}
buffer->retrieve(indexFirstMask + 4 + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
else
{
if (buffer->readableBytes() >= (indexFirstMask + length))
{
auto rawData = buffer->peek() + indexFirstMask;
message.append(rawData, length);
buffer->retrieve(indexFirstMask + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
}
return true;
}
} // namespace drogon

View File

@ -31,6 +31,5 @@ namespace drogon
const string_view &webContentTypeToString(ContentType contenttype);
const string_view &statusCodeToString(int code);
bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type);
} // namespace drogon

View File

@ -1,178 +0,0 @@
/**
*
* WebSocketConnectionImpl.cc
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "WebSockectConnectionImpl.h"
#include <trantor/net/TcpConnection.h>
#include <thread>
using namespace drogon;
WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer)
: _tcpConn(conn),
_localAddr(conn->localAddr()),
_peerAddr(conn->peerAddr()),
_isServer(isServer)
{
}
void WebSocketConnectionImpl::send(const char *msg, uint64_t len, const WebSocketMessageType &type)
{
unsigned char opcode;
if (type == WebSocketMessageType::Text)
opcode = 1;
else if (type == WebSocketMessageType::Binary)
opcode = 2;
else if (type == WebSocketMessageType::Close)
opcode = 8;
else if (type == WebSocketMessageType::Ping)
opcode = 9;
else if (type == WebSocketMessageType::Pong)
opcode = 10;
else
{
opcode = 0;
assert(0);
}
sendWsData(msg, len, opcode);
}
void WebSocketConnectionImpl::sendWsData(const char *msg, size_t len, unsigned char opcode)
{
LOG_TRACE << "send " << len << " bytes";
auto conn = _tcpConn.lock();
if (conn)
{
//Format the frame
std::string bytesFormatted;
bytesFormatted.resize(len + 10);
bytesFormatted[0] = char(0x80 | (opcode & 0x0f));
int indexStartRawData = -1;
if (len <= 125)
{
bytesFormatted[1] = len;
indexStartRawData = 2;
}
else if (len <= 65535)
{
bytesFormatted[1] = 126;
bytesFormatted[2] = ((len >> 8) & 255);
bytesFormatted[3] = ((len)&255);
LOG_TRACE << "bytes[2]=" << (size_t)bytesFormatted[2];
LOG_TRACE << "bytes[3]=" << (size_t)bytesFormatted[3];
indexStartRawData = 4;
}
else
{
bytesFormatted[1] = 127;
bytesFormatted[2] = ((len >> 56) & 255);
bytesFormatted[3] = ((len >> 48) & 255);
bytesFormatted[4] = ((len >> 40) & 255);
bytesFormatted[5] = ((len >> 32) & 255);
bytesFormatted[6] = ((len >> 24) & 255);
bytesFormatted[7] = ((len >> 16) & 255);
bytesFormatted[8] = ((len >> 8) & 255);
bytesFormatted[9] = ((len)&255);
indexStartRawData = 10;
}
if (!_isServer)
{
//Add masking key;
static std::once_flag once;
std::call_once(once, []() {
std::srand(time(nullptr));
});
int random = std::rand();
bytesFormatted[1] = (bytesFormatted[1] | 0x80);
bytesFormatted.resize(indexStartRawData + 4 + len);
*((int *)&bytesFormatted[indexStartRawData]) = random;
for (size_t i = 0; i < len; i++)
{
bytesFormatted[indexStartRawData + 4 + i] = (msg[i] ^ bytesFormatted[indexStartRawData + (i % 4)]);
}
}
else
{
bytesFormatted.resize(indexStartRawData);
bytesFormatted.append(msg, len);
}
conn->send(bytesFormatted);
}
}
void WebSocketConnectionImpl::send(const std::string &msg, const WebSocketMessageType &type)
{
send(msg.data(), msg.length(), type);
}
const trantor::InetAddress &WebSocketConnectionImpl::localAddr() const
{
return _localAddr;
}
const trantor::InetAddress &WebSocketConnectionImpl::peerAddr() const
{
return _peerAddr;
}
bool WebSocketConnectionImpl::connected() const
{
auto conn = _tcpConn.lock();
if (conn)
{
return conn->connected();
}
return false;
}
bool WebSocketConnectionImpl::disconnected() const
{
auto conn = _tcpConn.lock();
if (conn)
{
return conn->disconnected();
}
return true;
}
void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown()
{
auto conn = _tcpConn.lock();
if (conn)
{
conn->shutdown();
}
}
void WebSocketConnectionImpl::WebSocketConnectionImpl::forceClose()
{
auto conn = _tcpConn.lock();
if (conn)
{
conn->forceClose();
}
}
void WebSocketConnectionImpl::setContext(const any &context)
{
_context = context;
}
const any &WebSocketConnectionImpl::WebSocketConnectionImpl::getContext() const
{
return _context;
}
any *WebSocketConnectionImpl::WebSocketConnectionImpl::getMutableContext()
{
return &_context;
}

View File

@ -1,63 +0,0 @@
/**
*
* WebSocketConnectionImpl.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/WebSocketConnection.h>
#include <drogon/WebSocketController.h>
namespace drogon
{
class WebSocketConnectionImpl : public WebSocketConnection
{
public:
explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer = true);
virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual const trantor::InetAddress &localAddr() const override;
virtual const trantor::InetAddress &peerAddr() const override;
virtual bool connected() const override;
virtual bool disconnected() const override;
virtual void shutdown() override; //close write
virtual void forceClose() override; //close
virtual void setContext(const any &context) override;
virtual const any &getContext() const override;
virtual any *getMutableContext() override;
void setController(const WebSocketControllerBasePtr &ctrl)
{
_ctrlPtr = ctrl;
}
WebSocketControllerBasePtr controller()
{
return _ctrlPtr;
}
private:
std::weak_ptr<trantor::TcpConnection> _tcpConn;
trantor::InetAddress _localAddr;
trantor::InetAddress _peerAddr;
WebSocketControllerBasePtr _ctrlPtr;
any _context;
bool _isServer = true;
void sendWsData(const char *msg, size_t len, unsigned char opcode);
};
typedef std::shared_ptr<WebSocketConnectionImpl> WebSocketConnectionImplPtr;
} // namespace drogon

View File

@ -121,6 +121,7 @@ void WebSocketClientImpl::connectToServerInLoop()
{
LOG_TRACE << "connection disconnect";
thisPtr->_connectionClosedCallback(thisPtr);
thisPtr->_websockConnPtr.reset();
thisPtr->_loop->runAfter(1.0, [thisPtr]() {
thisPtr->reconnect();
});
@ -154,36 +155,8 @@ void WebSocketClientImpl::connectToServerInLoop()
void WebSocketClientImpl::onRecvWsMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer)
{
std::string message;
WebSocketMessageType type;
auto success = parseWebsockMessage(msgBuffer, message, type);
if (success)
{
if (type == WebSocketMessageType::Close)
{
//close
connPtr->shutdown();
}
else if (type == WebSocketMessageType::Ping)
{
//ping
if (_websockConnPtr)
{
_websockConnPtr->send(message, WebSocketMessageType::Pong);
}
}
_messageCallback(std::move(message), shared_from_this(), type);
}
else
{
//Websock error!
connPtr->shutdown();
auto thisPtr = shared_from_this();
_loop->runAfter(1.0, [thisPtr]() {
thisPtr->reconnect();
});
return;
}
assert(_websockConnPtr);
_websockConnPtr->onNewMessage(connPtr, msgBuffer);
}
void WebSocketClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer)
@ -234,6 +207,12 @@ void WebSocketClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr
_upgraded = true;
_websockConnPtr = std::make_shared<WebSocketConnectionImpl>(connPtr, false);
auto thisPtr = shared_from_this();
_websockConnPtr->setMessageCallback([thisPtr](std::string &&message,
const WebSocketConnectionImplPtr &connPtr,
const WebSocketMessageType &type) {
thisPtr->_messageCallback(std::move(message), thisPtr, type);
});
_requestCallback(ReqResult::Ok, resp, shared_from_this());
if (msgBuffer->readableBytes() > 0)
{

View File

@ -14,7 +14,7 @@
#pragma once
#include "WebSockectConnectionImpl.h"
#include "WebSocketConnectionImpl.h"
#include <drogon/WebSocketClient.h>
#include <trantor/utils/NonCopyable.h>
#include <trantor/net/EventLoop.h>
@ -29,7 +29,7 @@ namespace drogon
class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_from_this<WebSocketClientImpl>
{
public:
virtual const WebSocketConnectionPtr &getConnection() override
virtual WebSocketConnectionPtr getConnection() override
{
return _websockConnPtr;
}
@ -50,6 +50,7 @@ class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_fr
virtual void connectToServer(const HttpRequestPtr &request, const WebSocketRequestCallback &callback) override
{
assert(callback);
if (_loop->isInLoopThread())
{
_upgradeRequest = request;
@ -87,10 +88,10 @@ class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_fr
trantor::TimerId _heartbeatTimerId;
HttpRequestPtr _upgradeRequest;
std::function<void(std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &)> _messageCallback;
std::function<void(const WebSocketClientPtr &)> _connectionClosedCallback;
std::function<void(std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &)> _messageCallback = [](std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &) {};
std::function<void(const WebSocketClientPtr &)> _connectionClosedCallback = [](const WebSocketClientPtr &) {};
WebSocketRequestCallback _requestCallback;
WebSocketConnectionPtr _websockConnPtr;
WebSocketConnectionImplPtr _websockConnPtr;
void connectToServerInLoop();
void sendReq(const trantor::TcpConnectionPtr &connPtr);

View File

@ -0,0 +1,292 @@
/**
*
* WebSocketConnectionImpl.cc
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "WebSocketConnectionImpl.h"
#include <trantor/net/TcpConnection.h>
#include <trantor/net/inner/TcpConnectionImpl.h>
#include <thread>
using namespace drogon;
WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer)
: _tcpConn(conn),
_localAddr(conn->localAddr()),
_peerAddr(conn->peerAddr()),
_isServer(isServer)
{
}
void WebSocketConnectionImpl::send(const char *msg, uint64_t len, const WebSocketMessageType &type)
{
unsigned char opcode;
if (type == WebSocketMessageType::Text)
opcode = 1;
else if (type == WebSocketMessageType::Binary)
opcode = 2;
else if (type == WebSocketMessageType::Close)
{
assert(len <= 125);
opcode = 8;
}
else if (type == WebSocketMessageType::Ping)
{
assert(len <= 125);
opcode = 9;
}
else if (type == WebSocketMessageType::Pong)
{
assert(len <= 125);
opcode = 10;
}
else
{
opcode = 0;
assert(0);
}
sendWsData(msg, len, opcode);
}
void WebSocketConnectionImpl::sendWsData(const char *msg, size_t len, unsigned char opcode)
{
LOG_TRACE << "send " << len << " bytes";
//Format the frame
std::string bytesFormatted;
bytesFormatted.resize(len + 10);
bytesFormatted[0] = char(0x80 | (opcode & 0x0f));
int indexStartRawData = -1;
if (len <= 125)
{
bytesFormatted[1] = len;
indexStartRawData = 2;
}
else if (len <= 65535)
{
bytesFormatted[1] = 126;
bytesFormatted[2] = ((len >> 8) & 255);
bytesFormatted[3] = ((len)&255);
LOG_TRACE << "bytes[2]=" << (size_t)bytesFormatted[2];
LOG_TRACE << "bytes[3]=" << (size_t)bytesFormatted[3];
indexStartRawData = 4;
}
else
{
bytesFormatted[1] = 127;
bytesFormatted[2] = ((len >> 56) & 255);
bytesFormatted[3] = ((len >> 48) & 255);
bytesFormatted[4] = ((len >> 40) & 255);
bytesFormatted[5] = ((len >> 32) & 255);
bytesFormatted[6] = ((len >> 24) & 255);
bytesFormatted[7] = ((len >> 16) & 255);
bytesFormatted[8] = ((len >> 8) & 255);
bytesFormatted[9] = ((len)&255);
indexStartRawData = 10;
}
if (!_isServer)
{
//Add masking key;
static std::once_flag once;
std::call_once(once, []() {
std::srand(time(nullptr));
});
int random = std::rand();
bytesFormatted[1] = (bytesFormatted[1] | 0x80);
bytesFormatted.resize(indexStartRawData + 4 + len);
*((int *)&bytesFormatted[indexStartRawData]) = random;
for (size_t i = 0; i < len; i++)
{
bytesFormatted[indexStartRawData + 4 + i] = (msg[i] ^ bytesFormatted[indexStartRawData + (i % 4)]);
}
}
else
{
bytesFormatted.resize(indexStartRawData);
bytesFormatted.append(msg, len);
}
_tcpConn->send(bytesFormatted);
}
void WebSocketConnectionImpl::send(const std::string &msg, const WebSocketMessageType &type)
{
send(msg.data(), msg.length(), type);
}
const trantor::InetAddress &WebSocketConnectionImpl::localAddr() const
{
return _localAddr;
}
const trantor::InetAddress &WebSocketConnectionImpl::peerAddr() const
{
return _peerAddr;
}
bool WebSocketConnectionImpl::connected() const
{
return _tcpConn->connected();
}
bool WebSocketConnectionImpl::disconnected() const
{
return _tcpConn->disconnected();
}
void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown()
{
_tcpConn->shutdown();
}
void WebSocketConnectionImpl::WebSocketConnectionImpl::forceClose()
{
_tcpConn->forceClose();
}
void WebSocketConnectionImpl::setContext(const any &context)
{
_context = context;
}
const any &WebSocketConnectionImpl::WebSocketConnectionImpl::getContext() const
{
return _context;
}
any *WebSocketConnectionImpl::WebSocketConnectionImpl::getMutableContext()
{
return &_context;
}
bool WebSocketMessageParser::parse(trantor::MsgBuffer *buffer)
{
//According to the rfc6455
_gotAll = false;
if (buffer->readableBytes() >= 2)
{
unsigned char opcode = (*buffer)[0] & 0x0f;
bool isControlFrame = false;
switch (opcode)
{
case 0:
//continuation frame
break;
case 1:
_type = WebSocketMessageType::Text;
break;
case 2:
_type = WebSocketMessageType::Binary;
break;
case 8:
_type = WebSocketMessageType::Close;
isControlFrame = true;
break;
case 9:
_type = WebSocketMessageType::Ping;
isControlFrame = true;
break;
case 10:
_type = WebSocketMessageType::Pong;
isControlFrame = true;
break;
default:
LOG_ERROR << "Unknown frame type";
return false;
break;
}
bool isFin = (((*buffer)[0] & 0x80) == 0x80);
if (!isFin && isControlFrame)
{
//rfc6455-5.5
LOG_ERROR << "Bad frame: all control frames MUST NOT be fragmented";
return false;
}
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (isControlFrame)
{
//rfc6455-5.5
LOG_ERROR << "Bad frame: all control frames MUST have a payload length of 125 bytes or less";
return false;
}
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
}
else
{
LOG_ERROR << "Websock parsing failed!";
return false;
}
}
if (isMasked != 0)
{
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
auto oldLen = _message.length();
_message.resize(oldLen + length);
for (size_t i = 0; i < length; i++)
{
_message[oldLen + i] = (rawData[i] ^ masks[i % 4]);
}
if (isFin)
_gotAll = true;
buffer->retrieve(indexFirstMask + 4 + length);
return true;
}
}
else
{
if (buffer->readableBytes() >= (indexFirstMask + length))
{
auto rawData = buffer->peek() + indexFirstMask;
_message.append(rawData, length);
if (isFin)
_gotAll = true;
buffer->retrieve(indexFirstMask + length);
return true;
}
}
}
return true;
}

View File

@ -0,0 +1,145 @@
/**
*
* WebSocketConnectionImpl.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/WebSocketConnection.h>
#include <drogon/WebSocketController.h>
namespace drogon
{
class WebSocketConnectionImpl;
typedef std::shared_ptr<WebSocketConnectionImpl> WebSocketConnectionImplPtr;
class WebSocketMessageParser
{
public:
bool parse(trantor::MsgBuffer *buffer);
bool gotAll(std::string &message, WebSocketMessageType &type)
{
assert(message.empty());
if (!_gotAll)
return false;
message.swap(_message);
type = _type;
return true;
}
private:
std::string _message;
WebSocketMessageType _type;
bool _gotAll = false;
};
class WebSocketConnectionImpl : public WebSocketConnection, public std::enable_shared_from_this<WebSocketConnectionImpl>
{
public:
explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer = true);
virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual const trantor::InetAddress &localAddr() const override;
virtual const trantor::InetAddress &peerAddr() const override;
virtual bool connected() const override;
virtual bool disconnected() const override;
virtual void shutdown() override; //close write
virtual void forceClose() override; //close
virtual void setContext(const any &context) override;
virtual const any &getContext() const override;
virtual any *getMutableContext() override;
void setMessageCallback(const std::function<void(std::string &&,
const WebSocketConnectionImplPtr &,
const WebSocketMessageType &)> &callback)
{
_messageCallback = callback;
}
void setCloseCallback(const std::function<void(const WebSocketConnectionImplPtr &)> &callback)
{
_closeCallback = callback;
}
void onNewMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *buffer)
{
while (buffer->readableBytes() > 0)
{
auto success = _parser.parse(buffer);
if (success)
{
std::string message;
WebSocketMessageType type;
if (_parser.gotAll(message, type))
{
if (type == WebSocketMessageType::Ping)
{
//ping
send(message, WebSocketMessageType::Pong);
}
else if (type == WebSocketMessageType::Close)
{
//close
connPtr->shutdown();
}
else if (type == WebSocketMessageType::Unknown)
{
return;
}
_messageCallback(std::move(message), shared_from_this(), type);
}
else
{
return;
}
}
else
{
//Websock error!
connPtr->shutdown();
return;
}
}
return;
}
void onClose()
{
_closeCallback(shared_from_this());
}
private:
trantor::TcpConnectionPtr _tcpConn;
trantor::InetAddress _localAddr;
trantor::InetAddress _peerAddr;
any _context;
bool _isServer = true;
std::function<void(std::string &&,
const WebSocketConnectionImplPtr &,
const WebSocketMessageType &)>
_messageCallback = [](std::string &&,
const WebSocketConnectionImplPtr &,
const WebSocketMessageType &) {};
std::function<void(const WebSocketConnectionImplPtr &)> _closeCallback = [](const WebSocketConnectionImplPtr &) {};
void sendWsData(const char *msg, size_t len, unsigned char opcode);
WebSocketMessageParser _parser;
};
} // namespace drogon

View File

@ -41,7 +41,7 @@ void WebsocketControllersRouter::registerWebSocketController(const std::string &
void WebsocketControllersRouter::route(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr)
const WebSocketConnectionImplPtr &wsConnPtr)
{
std::string wsKey = req->getHeaderBy("sec-websocket-key");
if (!wsKey.empty())
@ -81,7 +81,7 @@ void WebsocketControllersRouter::doControllerHandler(const WebSocketControllerBa
std::string &wsKey,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr)
const WebSocketConnectionImplPtr &wsConnPtr)
{
wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
unsigned char accKey[SHA_DIGEST_LENGTH];
@ -93,9 +93,14 @@ void WebsocketControllersRouter::doControllerHandler(const WebSocketControllerBa
resp->addHeader("Connection", "Upgrade");
resp->addHeader("Sec-WebSocket-Accept", base64Key);
callback(resp);
auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr);
assert(wsConnImplPtr);
wsConnImplPtr->setController(ctrlPtr);
wsConnPtr->setMessageCallback([ctrlPtr](std::string &&message,
const WebSocketConnectionImplPtr &connPtr,
const WebSocketMessageType &type) {
ctrlPtr->handleNewMessage(connPtr, std::move(message), type);
});
wsConnPtr->setCloseCallback([ctrlPtr](const WebSocketConnectionImplPtr &connPtr) {
ctrlPtr->handleConnectionClosed(connPtr);
});
ctrlPtr->handleNewConnection(req, wsConnPtr);
return;
}

View File

@ -15,6 +15,7 @@
#pragma once
#include "HttpRequestImpl.h"
#include "HttpResponseImpl.h"
#include "WebSocketConnectionImpl.h"
#include <trantor/utils/NonCopyable.h>
#include <drogon/WebSocketController.h>
#include <drogon/HttpFilter.h>
@ -36,7 +37,7 @@ class WebsocketControllersRouter : public trantor::NonCopyable
const std::vector<std::string> &filters);
void route(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr);
const WebSocketConnectionImplPtr &wsConnPtr);
void init();
private:
@ -53,6 +54,6 @@ class WebsocketControllersRouter : public trantor::NonCopyable
std::string &wsKey,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr);
const WebSocketConnectionImplPtr &wsConnPtr);
};
} // namespace drogon

View File

@ -19,6 +19,7 @@ killall -9 webapp
sleep 4
echo "Test http requests and responses."
./webapp_test
if [ $? -ne 0 ];then
@ -27,6 +28,7 @@ if [ $? -ne 0 ];then
fi
#Test WebSocket
echo "Test the WebSocket"
./websocket_test -t
if [ $? -ne 0 ];then
echo "Error in testing"
@ -34,6 +36,7 @@ if [ $? -ne 0 ];then
fi
#Test pipelining
echo "Test the pipelining"
./pipelining_test
if [ $? -ne 0 ];then
echo "Error in testing"
@ -43,7 +46,7 @@ fi
killall -9 webapp
#Test drogon_ctl
echo "Test the drogon_ctl"
rm -rf drogon_test
drogon_ctl create project drogon_test