diff --git a/examples/simple_example_test/main.cc b/examples/simple_example_test/main.cc index 00843d6c..1c392a2c 100644 --- a/examples/simple_example_test/main.cc +++ b/examples/simple_example_test/main.cc @@ -1390,10 +1390,8 @@ int main(int argc, char *argv[]) if (app().supportSSL()) { std::promise pro2; - auto sslClient = HttpClient::newHttpClient("127.0.0.1", - 8849, - true, - loop[1].getLoop()); + auto sslClient = HttpClient::newHttpClient( + "127.0.0.1", 8849, true, loop[1].getLoop(), false, false); if (sessionID) sslClient->addCookie(sessionID); doTest(sslClient, pro2, true); diff --git a/lib/inc/drogon/HttpClient.h b/lib/inc/drogon/HttpClient.h index 4fd24ccc..93666d1f 100644 --- a/lib/inc/drogon/HttpClient.h +++ b/lib/inc/drogon/HttpClient.h @@ -217,7 +217,8 @@ class HttpClient : public trantor::NonCopyable uint16_t port, bool useSSL = false, trantor::EventLoop *loop = nullptr, - bool useOldTLS = false); + bool useOldTLS = false, + bool validateCert = true); /// Get the event loop of the client; virtual trantor::EventLoop *getLoop() = 0; @@ -255,7 +256,8 @@ class HttpClient : public trantor::NonCopyable */ static HttpClientPtr newHttpClient(const std::string &hostString, trantor::EventLoop *loop = nullptr, - bool useOldTLS = false); + bool useOldTLS = false, + bool validateCert = true); virtual ~HttpClient() { diff --git a/lib/inc/drogon/HttpTypes.h b/lib/inc/drogon/HttpTypes.h index 7e9837c0..c3ab2a14 100644 --- a/lib/inc/drogon/HttpTypes.h +++ b/lib/inc/drogon/HttpTypes.h @@ -131,7 +131,9 @@ enum class ReqResult BadResponse, NetworkFailure, BadServerAddress, - Timeout + Timeout, + HandshakeError, + InvalidCertificate, }; enum class WebSocketMessageType diff --git a/lib/src/HttpClientImpl.cc b/lib/src/HttpClientImpl.cc index 6ba1e731..cbd2c5ad 100644 --- a/lib/src/HttpClientImpl.cc +++ b/lib/src/HttpClientImpl.cc @@ -39,7 +39,7 @@ void HttpClientImpl::createTcpClient() { LOG_TRACE << "useOldTLS=" << useOldTLS_; LOG_TRACE << "domain=" << domain_; - tcpClientPtr_->enableSSL(useOldTLS_, domain_); + tcpClientPtr_->enableSSL(useOldTLS_, validateCert_, domain_); } #endif auto thisPtr = shared_from_this(); @@ -107,21 +107,41 @@ void HttpClientImpl::createTcpClient() thisPtr->onRecvMessage(connPtr, msg); } }); + tcpClientPtr_->setSSLErrorCallback([weakPtr](SSLError err) { + auto thisPtr = weakPtr.lock(); + if (!thisPtr) + return; + if (err == trantor::SSLError::kSSLHandshakeError) + thisPtr->onError(ReqResult::HandshakeError); + else if (err == trantor::SSLError::kSSLInvalidCertificate) + thisPtr->onError(ReqResult::InvalidCertificate); + else + { + LOG_FATAL << "Invalid value for SSLError"; + abort(); + } + }); tcpClientPtr_->connect(); } HttpClientImpl::HttpClientImpl(trantor::EventLoop *loop, const trantor::InetAddress &addr, bool useSSL, - bool useOldTLS) - : loop_(loop), serverAddr_(addr), useSSL_(useSSL), useOldTLS_(useOldTLS) + bool useOldTLS, + bool validateCert) + : loop_(loop), + serverAddr_(addr), + useSSL_(useSSL), + validateCert_(validateCert), + useOldTLS_(useOldTLS) { } HttpClientImpl::HttpClientImpl(trantor::EventLoop *loop, const std::string &hostString, - bool useOldTLS) - : loop_(loop), useOldTLS_(useOldTLS) + bool useOldTLS, + bool validateCert) + : loop_(loop), validateCert_(validateCert), useOldTLS_(useOldTLS) { auto lowerHost = hostString; std::transform(lowerHost.begin(), @@ -559,24 +579,28 @@ HttpClientPtr HttpClient::newHttpClient(const std::string &ip, uint16_t port, bool useSSL, trantor::EventLoop *loop, - bool useOldTLS) + bool useOldTLS, + bool validateCert) { bool isIpv6 = ip.find(':') == std::string::npos ? false : true; return std::make_shared( loop == nullptr ? HttpAppFrameworkImpl::instance().getLoop() : loop, trantor::InetAddress(ip, port, isIpv6), useSSL, - useOldTLS); + useOldTLS, + validateCert); } HttpClientPtr HttpClient::newHttpClient(const std::string &hostString, trantor::EventLoop *loop, - bool useOldTLS) + bool useOldTLS, + bool validateCert) { return std::make_shared( loop == nullptr ? HttpAppFrameworkImpl::instance().getLoop() : loop, hostString, - useOldTLS); + useOldTLS, + validateCert); } void HttpClientImpl::onError(ReqResult result) diff --git a/lib/src/HttpClientImpl.h b/lib/src/HttpClientImpl.h index 4464d8c6..aeab80ce 100644 --- a/lib/src/HttpClientImpl.h +++ b/lib/src/HttpClientImpl.h @@ -33,10 +33,12 @@ class HttpClientImpl : public HttpClient, HttpClientImpl(trantor::EventLoop *loop, const trantor::InetAddress &addr, bool useSSL = false, - bool useOldTLS = false); + bool useOldTLS = false, + bool validateCert = true); HttpClientImpl(trantor::EventLoop *loop, const std::string &hostString, - bool useOldTLS = false); + bool useOldTLS = false, + bool validateCert = true); virtual void sendRequest(const HttpRequestPtr &req, const HttpReqCallback &callback, double timeout = 0) override; @@ -83,6 +85,7 @@ class HttpClientImpl : public HttpClient, trantor::EventLoop *loop_; trantor::InetAddress serverAddr_; bool useSSL_; + bool validateCert_; void sendReq(const trantor::TcpConnectionPtr &connPtr, const HttpRequestPtr &req); void sendRequestInLoop(const HttpRequestPtr &req, diff --git a/trantor b/trantor index 59f857cd..4791f763 160000 --- a/trantor +++ b/trantor @@ -1 +1 @@ -Subproject commit 59f857cd061ee295c3a5cfa69b5a857935382025 +Subproject commit 4791f76335b0bf49fb6a43e9ff2053cdbb41515a