#include "lib_acl.h" #include #include #include #include #include #ifndef __USE_GNU #define __USE_GNU #endif #include #include "hook_close.h" static acl_pthread_key_t __trace_key; static acl_pthread_once_t __trace_once = ACL_PTHREAD_ONCE_INIT; static unsigned int *__main_buf = NULL; static void trace_buf_free(void *buf) { if ((unsigned long) acl_pthread_self() != acl_main_thread_self()) free(buf); } static void main_buf_free(void) { if (__main_buf) free(__main_buf); } static void trace_buf_init(void) { acl_assert(acl_pthread_key_create(&__trace_key, trace_buf_free) == 0); } static void trace_info(const char* prompt) { void *buffer[1000]; size_t n, i; char **results; unsigned int *intbuf; if (acl_pthread_once(&__trace_once, trace_buf_init) != 0) return; intbuf = (unsigned int *) acl_pthread_getspecific(__trace_key); if (intbuf == NULL) { intbuf = (unsigned int *) malloc(sizeof(int)); *intbuf = 0; acl_assert(acl_pthread_setspecific(__trace_key, intbuf) == 0); if ((unsigned long) acl_pthread_self() == acl_main_thread_self()) { __main_buf = intbuf; atexit(main_buf_free); } } if ((*intbuf) > 0) return; n = backtrace(buffer, 1000); if (n == 0) return; (*intbuf)++; results = backtrace_symbols(buffer, n); for (i = 0; i < n; i++) acl_msg_info("%s->backtrace: %s", prompt, results[i]); if (results != NULL) free(results); (*intbuf)--; } /****************************************************************************/ typedef int (*close_fn)(int); static close_fn __sys_close = NULL; void hook_setup(void) __attribute__ ((constructor)); void hook_setup(void) { __sys_close = (close_fn) dlsym(RTLD_NEXT, "close"); } static int __monitor_ports[100]; static int __monitor_ports_max = sizeof(__monitor_ports); static int __monitor_ports_cur = 0; int monitor_port_add(int port) { if (port <= 0) { acl_msg_error("%s(%d), %s: invalid port: %d", __FILE__, __LINE__, __FUNCTION__, port); return -1; } if (__monitor_ports_cur >= __monitor_ports_max) { acl_msg_error("too many ports added, cur: %d, max: %d", __monitor_ports_cur, __monitor_ports_max); return __monitor_ports_cur; } __monitor_ports[__monitor_ports_cur] = port; __monitor_ports_cur++; return __monitor_ports_cur; } static int should_check(int port) { int i; if (port <= 0) { acl_msg_error("%s(%d), %s: invalid port: %d", __FILE__, __LINE__, __FUNCTION__, port); return 0; } for (i = 0; i < __monitor_ports_cur; i++) { if (__monitor_ports[i] == port) return 1; } return 0; } int close(int fd) { char peer[64], *at, buf[256]; if (fd < 0) { acl_msg_error("%s(%d), %s: invalid fd: %d", __FILE__, __LINE__, __FUNCTION__, fd); trace_info("invalid fd"); return -1; } if (__monitor_ports_cur == 0 || acl_getpeername(fd, peer, sizeof(peer)) == -1) { peer[0] = 0; } else if ((at = strchr(peer, ':')) && *++at && should_check(atoi(at))) { snprintf(buf, sizeof(buf), "closing %s", peer); trace_info(buf); } if (__sys_close(fd) == 0) return 0; snprintf(buf, sizeof(buf), "close error %s, fd=%d, peer=%s", acl_last_serror(), fd, peer); trace_info(buf); return -1; }