// Mirror of https://gitee.com/ldcsaa/HP-Socket.git (synced 2024-12-02 03:37:50 +08:00; 496 lines, 15 KiB, C++)
/*
 * Copyright: JessMA Open Source (ldcsaa@gmail.com)
 *
 * Author	: Bruce Liang
 * Website	: https://github.com/ldcsaa
 * Project	: https://github.com/ldcsaa/HP-Socket
 * Blog		: http://www.cnblogs.com/ldcsaa
 * Wiki		: http://www.oschina.net/p/hp-socket
 * QQ Group	: 44636872, 75375912
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "../Include/HPSocket/HPTypeDef.h"

// Everything below is compiled only when SSL support is enabled
// (the matching #endif is the last line of this header).
#ifdef _SSL_SUPPORT

#include "Common/BufferPool.h"

// Silence MSVC warning C4005 (macro redefinition) while the OpenSSL headers are included.
#pragma warning(push)
#pragma warning(disable: 4005)

#include "openssl/ssl.h"

#pragma warning(pop)

// OPENSSL_VERSION_NUMBER thresholds used for conditional compilation in this header.
#define OPENSSL_VERSION_1_0_2	0x10002000L
#define OPENSSL_VERSION_1_1_0	0x10100000L

// Default cipher list used by CSSLContext; builds against OpenSSL 1.1.0+ additionally exclude SSLv3.
#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_1_0
	#define DEFAULT_CIPHER_LIST	_T("DEFAULT:!aNULL:!eNULL:!SSLv2")
#else
	#define DEFAULT_CIPHER_LIST	_T("DEFAULT:!aNULL:!eNULL:!SSLv2:!SSLv3")
#endif

/************************************************************************
Name        : SSL global constants
Description : Declares the public global constants of the SSL component
************************************************************************/

#define SSL_DOMAIN_SEP_CHAR '.'

/************************************************************************
Name        : SSL handshake status
Description : Identifies the SSL handshake status of the current connection
************************************************************************/
/* Handshake progress of one SSL connection (see CSSLSession::GetStatus()). */
enum EnSSLHandShakeStatus
{
	SSL_HSS_INIT	= 0,	// initial state
	SSL_HSS_PROC	= 1,	// handshake in progress
	SSL_HSS_SUCC	= 2,	// handshake succeeded
};
#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_1_0

/*
* SSL CRYPTO DYNLOCK structure.
* Defines the CRYPTO_dynlock_value type that the ssl_lock_dyn_* callbacks of
* CSSLInitializer create, lock and destroy. Only needed before OpenSSL 1.1.0.
*/
typedef struct CRYPTO_dynlock_value
{
	CSimpleRWLock cs;
} DynamicLock;

#endif
class CSSLInitializer
|
||
{
|
||
public:
|
||
|
||
static void CleanupThreadState(DWORD dwThreadID = 0);
|
||
|
||
private:
|
||
|
||
CSSLInitializer();
|
||
~CSSLInitializer();
|
||
|
||
private:
|
||
|
||
#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_1_0
|
||
static void ssl_lock_callback(int mode, int n, const char *file, int line);
|
||
static CRYPTO_dynlock_value* ssl_lock_dyn_create_callback(const char *file, int line);
|
||
static void ssl_lock_dyn_callback(int mode, CRYPTO_dynlock_value* l, const char *file, int line);
|
||
static void ssl_lock_dyn_destroy_callback(CRYPTO_dynlock_value* l, const char *file, int line);
|
||
#endif
|
||
|
||
private:
|
||
|
||
#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_1_0
|
||
static int sm_iLockNum;
|
||
static CSimpleRWLock* sm_pcsLocks;
|
||
#endif
|
||
|
||
static CSSLInitializer sm_instance;
|
||
|
||
};

/************************************************************************
Name        : SSL Context
Description : Initializes and cleans up the SSL runtime environment
************************************************************************/
class CSSLContext
|
||
{
|
||
typedef unordered_map<CString, int, cstring_nc_hash_func::hash, cstring_nc_hash_func::equal_to> CServerNameMap;
|
||
|
||
public:
|
||
|
||
/*
|
||
* 名称:初始化 SSL 环境参数
|
||
* 描述:SSL 环境参数必须在 SSL 通信组件启动前完成初始化,否则启动失败
|
||
*
|
||
* 参数: enSessionMode -- SSL 工作模式(参考 EnSSLSessionMode)
|
||
* iVerifyMode -- SSL 验证模式(参考 EnSSLVerifyMode)
|
||
* lpPemCert -- 证书文件(客户端可选)
|
||
* lpPemKey -- 私钥文件(客户端可选)
|
||
* lpKeyPasswod -- 私钥密码(没有密码则为空)
|
||
* lpCAPemCert -- CA 证书文件或目录(单向验证或客户端可选)
|
||
* fnServerNameCallback -- SNI 回调函数指针(可选,只用于服务端,如果为 nullptr 则使用 SNI 默认回调函数)
|
||
*
|
||
* 返回值: TRUE -- 成功
|
||
* FALSE -- 失败,可通过 ::GetLastError() 获取失败原因
|
||
*/
|
||
BOOL Initialize(EnSSLSessionMode enSessionMode, int iVerifyMode = SSL_VM_NONE, BOOL bMemory = FALSE, LPVOID lpPemCert = nullptr, LPVOID lpPemKey = nullptr, LPVOID lpKeyPasswod = nullptr, LPVOID lpCAPemCert = nullptr, Fn_SNI_ServerNameCallback fnServerNameCallback = nullptr);
|
||
|
||
/*
|
||
* 名称:增加 SNI 主机证书(只用于服务端)
|
||
* 描述:SSL 服务端在 Initialize() 成功后可以调用本方法增加多个 SNI 主机证书
|
||
*
|
||
* 参数: iVerifyMode -- SSL 验证模式(参考 EnSSLVerifyMode)
|
||
* lpPemCert -- 证书文件
|
||
* lpPemKey -- 私钥文件
|
||
* lpKeyPasswod -- 私钥密码(没有密码则为空)
|
||
* lpCAPemCert -- CA 证书文件或目录(单向验证可选)
|
||
*
|
||
* 返回值: 正数 -- 成功,并返回 SNI 主机证书对应的索引,该索引用于在 SNI 回调函数中定位 SNI 主机
|
||
* 负数 -- 失败,可通过 ::GetLastError() 获取失败原因
|
||
*/
|
||
int AddServerContext(int iVerifyMode, BOOL bMemory, LPVOID lpPemCert, LPVOID lpPemKey, LPVOID lpKeyPasswod = nullptr, LPVOID lpCAPemCert = nullptr);
|
||
|
||
/*
|
||
* 名称:绑定 SNI 主机域名
|
||
* 描述:SSL 服务端在 AddServerContext() 成功后可以调用本方法绑定主机域名到 SNI 主机证书
|
||
*
|
||
* 参数: lpszServerName -- 主机域名
|
||
* iContextIndex -- SNI 主机证书对应的索引
|
||
*
|
||
* 返回值: TRUE -- 成功
|
||
* FALSE -- 失败,可通过 ::GetLastError() 获取失败原因
|
||
*/
|
||
virtual BOOL BindServerName(LPCTSTR lpszServerName, int iContextIndex);
|
||
|
||
/*
|
||
* 名称:清理 SSL 运行环境
|
||
* 描述:清理 SSL 运行环境,回收 SSL 相关内存
|
||
* 1、CSSLContext 的析构函数会自动调用本方法
|
||
* 2、当要重新设置 SSL 环境参数时,需要先调用本方法清理原先的环境参数
|
||
*
|
||
* 参数: 无
|
||
*
|
||
* 返回值:无
|
||
*/
|
||
void Cleanup();
|
||
|
||
/* 获取 SSL 运行环境 SSL_CTX 对象 */
|
||
SSL_CTX* GetContext (int i) const;
|
||
/* 获取 SSL 运行环境默认 SSL_CTX 对象 */
|
||
SSL_CTX* GetDefaultContext () const {return m_sslCtx;}
|
||
/* 获取 SSL 运行环境的配置模式,配置模式参考:EnSSLSessionMode */
|
||
EnSSLSessionMode GetSessionMode () const {return m_enSessionMode;}
|
||
/* 检查 SSL 运行环境是否初始化完成 */
|
||
BOOL IsValid () const {return m_sslCtx != nullptr;}
|
||
|
||
/* 设置 SSL 加密算法列表 */
|
||
void SetCipherList(LPCTSTR lpszCipherList) {m_strCipherList = lpszCipherList;}
|
||
/* 获取 SSL 加密算法列表 */
|
||
LPCTSTR GetCipherList() {return m_strCipherList;}
|
||
|
||
public:
|
||
|
||
/*
|
||
* 名称:清理线程局部环境 SSL 资源
|
||
* 描述:任何一个操作 SSL 的线程,在通信结束时都需要清理线程局部环境 SSL 资源
|
||
* 1、主线程和 HP-Socket 工作线程在通信结束时会自动清理线程局部环境 SSL 资源。因此,一般情况下不必手工调用本方法
|
||
* 2、特殊情况下,当自定义线程参与 HP-Socket 通信操作并检查到 SSL 内存泄漏时,需在每次通信结束时自定义线程调用本方法
|
||
*
|
||
* 参数: dwThreadID -- 线程 ID(0:当前线程)
|
||
*
|
||
* 返回值:无
|
||
*/
|
||
static void RemoveThreadLocalState(DWORD dwThreadID = 0) {CSSLInitializer::CleanupThreadState(dwThreadID);}
|
||
|
||
public:
|
||
|
||
CSSLContext()
|
||
: m_strCipherList (DEFAULT_CIPHER_LIST)
|
||
, m_enSessionMode (SSL_SM_SERVER)
|
||
, m_sslCtx (nullptr)
|
||
, m_fnServerNameCallback(nullptr)
|
||
{
|
||
|
||
}
|
||
|
||
~CSSLContext() {Cleanup();}
|
||
|
||
private:
|
||
|
||
void SetServerNameCallback(Fn_SNI_ServerNameCallback fn);
|
||
int AddContext(int iVerifyMode, BOOL bMemory, LPVOID lpPemCert, LPVOID lpPemKey, LPVOID lpKeyPasswod, LPVOID lpCAPemCert);
|
||
BOOL LoadCertAndKey(SSL_CTX* sslCtx, int iVerifyMode, BOOL bMemory, LPVOID lpPemCert, LPVOID lpPemKey, LPVOID lpKeyPasswod, LPVOID lpCAPemCert);
|
||
BOOL LoadCertAndKeyByFile(SSL_CTX* sslCtx, int iVerifyMode, LPCTSTR lpszPemCertFile, LPCTSTR lpszPemKeyFile, LPCTSTR lpszKeyPassword, LPCTSTR lpszCAPemCertFileOrPath);
|
||
BOOL LoadCertAndKeyByMemory(SSL_CTX* sslCtx, int iVerifyMode, LPCSTR lpszPemCert, LPCSTR lpszPemKey, LPCSTR lpszKeyPassword, LPCSTR lpszCAPemCert);
|
||
BOOL LoadCAPemCertByMemory(SSL_CTX* sslCtx, int iVerifyMode, LPCSTR lpszCAPemCert);
|
||
BOOL LoadPemCertAndKeyByMemory(SSL_CTX* sslCtx, LPCSTR lpszPemCert, LPCSTR lpszPemKey, LPCSTR lpszKeyPassword);
|
||
BOOL AddCAPemCertToStoreByMemory(SSL_CTX* sslCtx, LPCSTR lpszPemCert);
|
||
BOOL SetClientCAListByMemory(SSL_CTX* sslCtx, LPCSTR lpszPemCert);
|
||
BOOL SetPrivateKeyByMemory(SSL_CTX* sslCtx, LPCSTR lpszPemKey);
|
||
BOOL SetCertChainByMemory(SSL_CTX* sslCtx, LPCSTR lpszPemCert);
|
||
|
||
private:
|
||
|
||
static int InternalServerNameCallback(SSL* ssl, int* ad, void* arg);
|
||
|
||
public:
|
||
|
||
/*
|
||
* 名称:SNI 默认回调函数
|
||
* 描述:Initialize 方法中如果不指定 SNI 回调函数则使用此 SNI 默认回调函数
|
||
*
|
||
* 参数: lpszServerName -- 请求域名
|
||
* pContext -- SSL Context 对象
|
||
*
|
||
* 返回值:SNI 主机证书对应的索引
|
||
*/
|
||
static int __HP_CALL DefaultServerNameCallback(LPCTSTR lpszServerName, PVOID pContext);
|
||
|
||
private:
|
||
|
||
CString m_strCipherList;
|
||
EnSSLSessionMode m_enSessionMode;
|
||
CServerNameMap m_sslServerNames;
|
||
vector<SSL_CTX*> m_lsSslCtxs;
|
||
SSL_CTX* m_sslCtx;
|
||
|
||
Fn_SNI_ServerNameCallback m_fnServerNameCallback;
|
||
};
class CSSLSession
|
||
{
|
||
public:
|
||
|
||
BOOL WriteRecvChannel(const BYTE* pData, int iLength);
|
||
BOOL ReadRecvChannel();
|
||
|
||
BOOL WriteSendChannel(const BYTE* pData, int iLength);
|
||
BOOL WriteSendChannel(const WSABUF pBuffers[], int iCount);
|
||
BOOL ReadSendChannel();
|
||
|
||
const WSABUF& GetRecvBuffer() const {return m_bufRecv;}
|
||
const WSABUF& GetSendBuffer() const {return m_bufSend;}
|
||
|
||
CSSLSession* Renew(const CSSLContext& sslCtx, LPCSTR lpszHostName = nullptr);
|
||
BOOL Reset();
|
||
BOOL IsValid() const {return GetStatus() != SSL_HSS_INIT;}
|
||
BOOL IsHandShaking() const {return GetStatus() == SSL_HSS_PROC;}
|
||
BOOL IsReady() const {return GetStatus() == SSL_HSS_SUCC;}
|
||
EnSSLHandShakeStatus GetStatus() const {return m_enStatus;}
|
||
DWORD GetFreeTime() const {return m_dwFreeTime;}
|
||
CCriSec& GetSendLock() {return m_csSend;}
|
||
BOOL GetSessionInfo(EnSSLSessionInfo enInfo, LPVOID* lppInfo);
|
||
|
||
private:
|
||
|
||
BOOL IsFatalError(int iBytes);
|
||
|
||
public:
|
||
|
||
CSSLSession(CItemPool& itPool)
|
||
: m_enStatus(SSL_HSS_INIT)
|
||
, m_itPool (itPool)
|
||
, m_ssl (nullptr)
|
||
, m_bioSend (nullptr)
|
||
, m_bioRecv (nullptr)
|
||
, m_pitSend (nullptr)
|
||
, m_pitRecv (nullptr)
|
||
{
|
||
|
||
}
|
||
|
||
~CSSLSession()
|
||
{
|
||
Reset();
|
||
}
|
||
|
||
static CSSLSession* Construct(CItemPool& itPool)
|
||
{return new CSSLSession(itPool);}
|
||
|
||
static void Destruct(CSSLSession* pSession)
|
||
{if(pSession) delete pSession;}
|
||
|
||
private:
|
||
CItemPool& m_itPool;
|
||
CCriSec m_csSend;
|
||
|
||
DWORD m_dwFreeTime;
|
||
EnSSLHandShakeStatus m_enStatus;
|
||
|
||
SSL* m_ssl;
|
||
BIO* m_bioSend;
|
||
BIO* m_bioRecv;
|
||
|
||
TItem* m_pitSend;
|
||
TItem* m_pitRecv;
|
||
WSABUF m_bufSend;
|
||
WSABUF m_bufRecv;
|
||
};
class CSSLSessionPool
|
||
{
|
||
typedef CRingPool<CSSLSession> TSSLSessionList;
|
||
typedef CCASQueue<CSSLSession> TSSLSessionQueue;
|
||
|
||
public:
|
||
CSSLSession* PickFreeSession (LPCSTR lpszHostName = nullptr);
|
||
void PutFreeSession (CSSLSession* pSession);
|
||
|
||
void Prepare ();
|
||
void Clear ();
|
||
|
||
private:
|
||
void ReleaseGCSession (BOOL bForce = FALSE);
|
||
|
||
public:
|
||
void SetItemCapacity (DWORD dwItemCapacity) {m_itPool.SetItemCapacity(dwItemCapacity);}
|
||
void SetItemPoolSize (DWORD dwItemPoolSize) {m_itPool.SetPoolSize(dwItemPoolSize);}
|
||
void SetItemPoolHold (DWORD dwItemPoolHold) {m_itPool.SetPoolHold(dwItemPoolHold);}
|
||
|
||
void SetSessionLockTime (DWORD dwSessionLockTime) {m_dwSessionLockTime = dwSessionLockTime;}
|
||
void SetSessionPoolSize (DWORD dwSessionPoolSize) {m_dwSessionPoolSize = dwSessionPoolSize;}
|
||
void SetSessionPoolHold (DWORD dwSessionPoolHold) {m_dwSessionPoolHold = dwSessionPoolHold;}
|
||
|
||
DWORD GetItemCapacity () {return m_itPool.GetItemCapacity();}
|
||
DWORD GetItemPoolSize () {return m_itPool.GetPoolSize();}
|
||
DWORD GetItemPoolHold () {return m_itPool.GetPoolHold();}
|
||
|
||
DWORD GetSessionLockTime() {return m_dwSessionLockTime;}
|
||
DWORD GetSessionPoolSize() {return m_dwSessionPoolSize;}
|
||
DWORD GetSessionPoolHold() {return m_dwSessionPoolHold;}
|
||
|
||
public:
|
||
CSSLSessionPool(const CSSLContext& sslCtx,
|
||
DWORD dwPoolSize = DEFAULT_SESSION_POOL_SIZE,
|
||
DWORD dwPoolHold = DEFAULT_SESSION_POOL_HOLD,
|
||
DWORD dwLockTime = DEFAULT_SESSION_LOCK_TIME)
|
||
: m_sslCtx(sslCtx)
|
||
, m_dwSessionPoolSize(dwPoolSize)
|
||
, m_dwSessionPoolHold(dwPoolHold)
|
||
, m_dwSessionLockTime(dwLockTime)
|
||
{
|
||
|
||
}
|
||
|
||
~CSSLSessionPool() {Clear();}
|
||
|
||
DECLARE_NO_COPY_CLASS(CSSLSessionPool)
|
||
|
||
public:
|
||
static const DWORD DEFAULT_ITEM_CAPACITY;
|
||
static const DWORD DEFAULT_ITEM_POOL_SIZE;
|
||
static const DWORD DEFAULT_ITEM_POOL_HOLD;
|
||
static const DWORD DEFAULT_SESSION_LOCK_TIME;
|
||
static const DWORD DEFAULT_SESSION_POOL_SIZE;
|
||
static const DWORD DEFAULT_SESSION_POOL_HOLD;
|
||
|
||
private:
|
||
CItemPool m_itPool;
|
||
const CSSLContext& m_sslCtx;
|
||
|
||
DWORD m_dwSessionLockTime;
|
||
DWORD m_dwSessionPoolSize;
|
||
DWORD m_dwSessionPoolHold;
|
||
|
||
TSSLSessionList m_lsFreeSession;
|
||
TSSLSessionQueue m_lsGCSession;
|
||
};
/*
* Drains the session's send channel to the socket.
*
* While holding the session's send lock, repeatedly calls ReadSendChannel() and passes
* each non-empty send buffer to pThis->DoSendPackets(); stops when the buffer is empty.
*
* Parameters : pThis      -- object providing DoSendPackets()
*              pSocketObj -- socket object forwarded to DoSendPackets()
*              pSession   -- SSL session (must not be null, it is dereferenced unconditionally)
*
* Returns    : HR_OK    -- everything pending was sent
*              HR_ERROR -- DoSendPackets() failed
*/
template<class T, class S> EnHandleResult ProcessHandShake(T* pThis, S* pSocketObj, CSSLSession* pSession)
{
	EnHandleResult result = HR_OK;

	CCriSecLock locallock(pSession->GetSendLock());

	while(TRUE)
	{
		ENSURE(pSession->ReadSendChannel());
		const WSABUF& buffer = pSession->GetSendBuffer();

		// nothing left to send
		if(buffer.len == 0)
			break;

		if(!pThis->DoSendPackets(pSocketObj, &buffer, 1))
		{
			result = HR_ERROR;
			break;
		}
	}

	return result;
}
/*
* Feeds data received from the network into the SSL session and dispatches the results.
*
* The data is written into the receive channel, then ReadRecvChannel() is called in a loop:
*   - when the session switches from "handshaking" to "ready", pending handshake output is
*     flushed with ProcessHandShake() and pThis->DoFireHandShake() is raised (once);
*   - every non-empty receive buffer is delivered through pThis->DoFireReceive();
*   - the loop ends when the receive buffer is empty or a step reports HR_ERROR.
* If the session is still handshaking afterwards, pending handshake output is flushed.
*
* Parameters : pThis      -- object providing DoSendPackets(), DoFireHandShake(), DoFireReceive()
*              pSocketObj -- socket object forwarded to those calls
*              pSession   -- SSL session (must not be null)
*              pData      -- received bytes
*              iLength    -- number of bytes in pData
*
* Returns    : HR_ERROR if a channel operation or a callback fails, otherwise the result
*              of the last callback / flush (HR_OK if none ran)
*/
template<class T, class S> EnHandleResult ProcessReceive(T* pThis, S* pSocketObj, CSSLSession* pSession, const BYTE* pData, int iLength)
{
	if(!pSession->WriteRecvChannel(pData, iLength))
		return HR_ERROR;

	EnHandleResult result			= HR_OK;
	// status before reading, used to detect the PROC -> SUCC transition exactly once
	EnSSLHandShakeStatus enStatus	= pSession->GetStatus();

	while(TRUE)
	{
		if(!pSession->ReadRecvChannel())
			return HR_ERROR;

		if(enStatus == SSL_HSS_PROC && pSession->IsReady())
		{
			// handshake just completed: flush remaining handshake output, then notify
			result = ::ProcessHandShake(pThis, pSocketObj, pSession);

			if(result == HR_ERROR)
				break;

			enStatus = SSL_HSS_SUCC;
			result	 = pThis->DoFireHandShake(pSocketObj);

			if(result == HR_ERROR)
				break;
		}

		const WSABUF& buffer = pSession->GetRecvBuffer();

		// no more data available
		if(buffer.len == 0)
			break;

		result = pThis->DoFireReceive(pSocketObj, (const BYTE*)buffer.buf, buffer.len);

		if(result == HR_ERROR)
			break;
	}

	// handshake still in progress: send whatever the handshake produced
	if(result != HR_ERROR && pSession->IsHandShaking())
		result = ::ProcessHandShake(pThis, pSocketObj, pSession);

	return result;
}
/*
* Sends caller data through the SSL session.
*
* Fails with ERROR_INVALID_STATE when there is no session or its handshake has not
* succeeded. Otherwise, while holding the session's send lock, writes the buffers into
* the send channel and passes every non-empty result of ReadSendChannel() to
* pThis->DoSendPackets().
*
* Parameters : pThis      -- object providing DoSendPackets()
*              pSocketObj -- socket object forwarded to DoSendPackets()
*              pSession   -- SSL session (may be null)
*              pBuffers   -- buffers to send
*              iCount     -- number of entries in pBuffers
*
* Returns    : TRUE  -- success
*              FALSE -- invalid state (see ::GetLastError()) or DoSendPackets() failed
*/
template<class T, class S> BOOL ProcessSend(T* pThis, S* pSocketObj, CSSLSession* pSession, const WSABUF * pBuffers, int iCount)
{
	if(pSession == nullptr || !pSession->IsReady())
	{
		::SetLastError(ERROR_INVALID_STATE);
		return FALSE;
	}

	CCriSecLock locallock(pSession->GetSendLock());

	// re-check under the lock: the state may have changed before the lock was acquired
	if(!pSession->IsReady())
	{
		::SetLastError(ERROR_INVALID_STATE);
		return FALSE;
	}

	ENSURE(pSession->WriteSendChannel(pBuffers, iCount));

	while(TRUE)
	{
		ENSURE(pSession->ReadSendChannel());
		const WSABUF& buffer = pSession->GetSendBuffer();

		// nothing left to send
		if(buffer.len == 0)
			break;

		if(!pThis->DoSendPackets(pSocketObj, &buffer, 1))
			return FALSE;
	}

	return TRUE;
}

#endif // _SSL_SUPPORT