add more safety protection for dns module

This commit is contained in:
shuxin   zheng 2019-12-16 10:43:50 +08:00
parent 6135813d6c
commit 994a27c7e9
7 changed files with 87 additions and 44 deletions

View File

@ -1,5 +1,8 @@
修改历史列表:
665) 2019.12.16
665.1) safety: add more safety protection for dns module.
664) 2019.12.11
664.1) feature: 支持 Linux 平台下的 abstract unix domain 抽象域套接字。

View File

@ -242,8 +242,10 @@ ACL_API void *acl_aio_dns(ACL_AIO *aio);
* @param aio {ACL_AIO*}
* @param dns_list {const char*} DNS ip1:port,ip2:port...
* @param timeout {int}
* @return {int} DNS 0 -1
* UDP UDP
*/
ACL_API void acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout);
ACL_API int acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout);
/**
* DNS

View File

@ -74,14 +74,15 @@ typedef struct ACL_DNS_REQ ACL_DNS_REQ;
* @param dns {ACL_DNS*} DNS异步查询句柄
* @param aio {ACL_AIO*}
* @param timeout {int} DNS查询时的超时值
* @return {int} 0 -1
*/
ACL_API void acl_dns_init(ACL_DNS *dns, ACL_AIO *aio, int timeout);
ACL_API int acl_dns_init(ACL_DNS *dns, ACL_AIO *aio, int timeout);
/**
* DNS异步查询对象并同时进行初始化
* @param aio {ACL_AIO*}
* @param timeout {int} DNS查询时的超时值
* @return {ACL_DNS*} DNSÒì²½²éѯ¾ä±ú
* @return {ACL_DNS*} DNS异步查询句柄 NULL DNS
*/
ACL_API ACL_DNS *acl_dns_create(ACL_AIO *aio, int timeout);

View File

@ -101,7 +101,7 @@ void *acl_aio_dns(ACL_AIO *aio)
return aio->dns;
}
void acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout)
int acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout)
{
ACL_ARGV *tokens;
ACL_ITER iter;
@ -111,11 +111,17 @@ void acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout)
if (tokens == NULL) {
acl_msg_error("%s(%d), %s: invalid dns_list=%s",
__FILE__, __LINE__, __FUNCTION__, dns_list);
return;
return -1;
}
if (aio->dns == NULL) {
aio->dns = acl_dns_create(aio, timeout);
if (aio->dns == NULL) {
acl_msg_error("%s(%d), %s: acl_dns_create error=%s",
__FILE__, __LINE__, __FUNCTION__,
acl_last_serror());
return -1;
}
/* acl_dns_check_dns_ip(aio->dns); */
}
@ -137,6 +143,7 @@ void acl_aio_set_dns(ACL_AIO *aio, const char *dns_list, int timeout)
}
acl_argv_free(tokens);
return 0;
}
void acl_aio_del_dns(ACL_AIO *aio, const char *dns_list)

View File

@ -474,17 +474,16 @@ int acl_aio_connect_addr(ACL_AIO *aio, const char *addr, int timeout,
acl_aio_add_connect_hook(conn, connect_callback, ctx);
acl_aio_add_timeo_hook(conn, connect_timeout, ctx);
acl_aio_add_close_hook(conn, connect_failed, ctx);
return 0;
} else if (aio->dns == NULL) {
acl_msg_error("%s(%d), %s: call acl_aio_set_dns first",
__FILE__, __LINE__, __FUNCTION__);
return -1;
} else {
if (aio->dns == NULL) {
acl_msg_error("%s(%d), %s: call acl_aio_set_dns first",
__FILE__, __LINE__, __FUNCTION__);
return -1;
}
acl_dns_lookup(aio->dns, buf, dns_lookup_callback, ctx);
return 0;
}
return 0;
}
int acl_astream_get_status(const ACL_ASTREAM_CTX *ctx)

View File

@ -43,7 +43,7 @@ struct ACL_DNS_REQ{
#define SAFE_COPY ACL_SAFE_STRNCPY
static void dns_stream_open(ACL_DNS *dns);
static int dns_stream_open(ACL_DNS *dns);
/* ACL_VSTREAM: 从数据流读取数据的回调函数 */
@ -355,11 +355,16 @@ static int dns_lookup_close(ACL_ASTREAM *server acl_unused, void *ctx acl_unused
/* 创建DNS查询的异步流 */
static void dns_stream_open(ACL_DNS *dns)
static int dns_stream_open(ACL_DNS *dns)
{
/* ndk9 居然要求 acl_vstream_bind 前加返回类型?*/
ACL_VSTREAM *stream = (ACL_VSTREAM*) acl_vstream_bind("0.0.0.0:0", 0, 0);
acl_assert(stream);
if (stream == NULL) {
acl_msg_error("%s(%d), %s: acl_vstream_bind error=%s",
__FILE__, __LINE__, __FUNCTION__, acl_last_serror());
return -1;
}
/* 创建异步流 */
dns->astream = acl_aio_open(dns->aio, stream);
@ -375,6 +380,7 @@ static void dns_stream_open(ACL_DNS *dns)
/* 设置该异步流为持续读状态 */
dns->astream->keep_read = 1;
return 0;
}
static void dns_lookup_send(ACL_DNS *dns, const char *domain, unsigned short qid)
@ -434,7 +440,7 @@ static void dns_lookup_timeout(int event_type, ACL_EVENT *event acl_unused,
acl_myfree(req);
}
void acl_dns_init(ACL_DNS *dns, ACL_AIO *aio, int timeout)
int acl_dns_init(ACL_DNS *dns, ACL_AIO *aio, int timeout)
{
dns->flag &= ~ACL_DNS_FLAG_ALLOC; /* 默认为栈空间 */
dns->aio = aio;
@ -453,17 +459,27 @@ void acl_dns_init(ACL_DNS *dns, ACL_AIO *aio, int timeout)
dns->lookup_timeout = dns_lookup_timeout;
/* 打开异步读取DNS服务器响应的数据流 */
dns_stream_open(dns);
if (dns_stream_open(dns) == -1) {
acl_msg_error("%s(%d), %s: dns_stream_open error=%s",
__FILE__, __LINE__, __FUNCTION__, acl_last_serror());
return -1;
}
/* 开始异步读查询结果 */
acl_aio_read(dns->astream);
return 0;
}
ACL_DNS *acl_dns_create(ACL_AIO *aio, int timeout)
{
ACL_DNS *dns = (ACL_DNS*) acl_mycalloc(1, sizeof(ACL_DNS));
acl_dns_init(dns, aio, timeout);
if (acl_dns_init(dns, aio, timeout) == -1) {
acl_myfree(dns);
acl_msg_error("%s(%d), %s: acl_dns_init error",
__FILE__, __LINE__, __FUNCTION__);
return NULL;
}
dns->flag |= ACL_DNS_FLAG_ALLOC; /* 设置为堆分配的变量 */
return dns;
}

View File

@ -34,13 +34,17 @@ void acl_res_set_timeout(int conn_timeout, int rw_timeout)
ACL_RES *acl_res_new(const char *dns_ip, unsigned short dns_port)
{
const char *myname = "acl_res_new";
ACL_RES *res;
if (dns_ip == NULL || *dns_ip == 0)
acl_msg_fatal("%s: dns_ip invalid", myname);
if (dns_port <= 0)
if (dns_ip == NULL || *dns_ip == 0) {
acl_msg_error("%s(%d), %s: dns_ip invalid",
__FILE__, __LINE__, __FUNCTION__);
return NULL;
}
if (dns_port <= 0) {
dns_port = 53;
}
res = acl_mycalloc(1, sizeof(ACL_RES));
res->cur_qid = (unsigned short) time(NULL);
@ -52,25 +56,29 @@ ACL_RES *acl_res_new(const char *dns_ip, unsigned short dns_port)
res->rw_timeout = __rw_timeout;
res->transfer = ACL_RES_USE_UDP;
return (res);
return res;
}
void acl_res_free(ACL_RES *res)
{
if (res)
if (res) {
acl_myfree(res);
}
}
static int udp_res_lookup(ACL_RES *res, const char *data, int dlen, char *buf, int size)
static int udp_res_lookup(ACL_RES *res, const char *data, int dlen,
char *buf, int size)
{
const char *myname = "udp_res_lookup";
ssize_t ret;
ACL_SOCKET fd;
struct sockaddr_in addr;
fd = socket(PF_INET, SOCK_DGRAM, 0);
if (fd == ACL_SOCKET_INVALID)
acl_msg_fatal("%s: socket create error", myname);
if (fd == ACL_SOCKET_INVALID) {
acl_msg_error("%s(%d), %s: socket create error=%s",
__FILE__, __LINE__, __FUNCTION__, acl_last_serror());
return -1;
}
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
@ -122,7 +130,7 @@ static int tcp_res_lookup(ACL_RES *res, const char *data,
snprintf(addr, sizeof(addr), "%s:%d", res->dns_ip, res->dns_port);
stream = acl_vstream_connect(addr, ACL_BLOCKING, res->conn_timeout,
res->rw_timeout, 1024);
res->rw_timeout, 1024);
if (stream == NULL) {
res->errnum = ACL_RES_ERR_CONN;
RETURN (-1);
@ -167,15 +175,15 @@ static int tcp_res_lookup(ACL_RES *res, const char *data,
static int res_lookup(ACL_RES *res, const char *data, int dlen,
char *buf, int size)
{
if (res->transfer == ACL_RES_USE_TCP)
if (res->transfer == ACL_RES_USE_TCP) {
return tcp_res_lookup(res, data, dlen, buf, size);
else
} else {
return udp_res_lookup(res, data, dlen, buf, size);
}
}
ACL_DNS_DB *acl_res_lookup(ACL_RES *res, const char *domain)
{
const char *myname = "acl_res_lookup";
ACL_DNS_DB *dns_db;
char buf[1024];
ssize_t ret, i;
@ -183,10 +191,14 @@ ACL_DNS_DB *acl_res_lookup(ACL_RES *res, const char *domain)
ACL_HOSTNAME *phost;
time_t begin;
if (res == NULL)
acl_msg_fatal("%s: res NULL", myname);
if (res == NULL) {
acl_msg_error("%s(%d), %s: res NULL",
__FILE__, __LINE__, __FUNCTION__);
return NULL;
}
if (domain == NULL || *domain == 0) {
acl_msg_error("%s: domain %s", myname, domain ? "empty" : "null");
acl_msg_error("%s(%d), %s: domain %s", __FILE__, __LINE__,
__FUNCTION__, domain ? "empty" : "null");
return NULL;
}
@ -197,8 +209,9 @@ ACL_DNS_DB *acl_res_lookup(ACL_RES *res, const char *domain)
ret = res_lookup(res, buf, (int) ret, buf, sizeof(buf));
res->tm_spent = time(NULL) - begin;
if (ret <= 0)
if (ret <= 0) {
return NULL;
}
ret = rfc1035MessageUnpack(buf, ret, &answers);
if (ret < 0) {
@ -226,8 +239,9 @@ ACL_DNS_DB *acl_res_lookup(ACL_RES *res, const char *domain)
(void) acl_array_append(dns_db->h_db, phost);
dns_db->size++;
} else if (acl_msg_verbose) {
acl_msg_error("%s: can't print answer type %d, domain %s",
myname, (int) answers->answer[i].type, domain);
acl_msg_error("%s(%d), %s: answer type %d, domain %s",
__FILE__, __LINE__, __FUNCTION__,
(int) answers->answer[i].type, domain);
}
}
@ -252,8 +266,9 @@ const char *acl_res_strerror(int errnum)
};
for (i = 0; errmsg[i].errnum != 0; i++) {
if (errmsg[i].errnum == errnum)
if (errmsg[i].errnum == errnum) {
return errmsg[i].msg;
}
}
return rfc1035Strerror(errnum);
@ -261,10 +276,10 @@ const char *acl_res_strerror(int errnum)
const char *acl_res_errmsg(const ACL_RES *res)
{
const char *myname = "acl_res_errmsg";
if (res == NULL)
acl_msg_fatal("%s: res null", myname);
if (res == NULL) {
acl_msg_error("%s: res null", __FUNCTION__);
return "res NULL";
}
return acl_res_strerror(res->errnum);
}