mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 03:37:55 +08:00
136 lines
4.6 KiB
Python
136 lines
4.6 KiB
Python
import os
|
||
import sys
|
||
import __main__
|
||
from functools import wraps, partial
|
||
from inspect import ismethod
|
||
from copy import deepcopy
|
||
from io import StringIO
|
||
import time
|
||
import signal
|
||
|
||
import numpy as np
|
||
|
||
from fastNLP.core.utils.utils import get_class_that_defined_method
|
||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
|
||
from fastNLP.core.drivers.utils import distributed_open_proc
|
||
from fastNLP.core.log import logger
|
||
|
||
|
||
def recover_logger(fn):
    """
    Decorator that restores the module-level ``logger`` after ``fn`` has run.

    The logger's handler list and level are snapshotted before the call and
    written back afterwards, so whatever ``fn`` does to the logger does not
    leak into later code. The restore also happens when ``fn`` raises.

    :param fn: the callable to wrap.
    :return: a wrapper that forwards all arguments to ``fn`` and returns its result.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Snapshot the logger state; copy the list so that in-place changes
        # made by ``fn`` (addHandler / removeHandler) do not alter the snapshot.
        handlers = list(logger.handlers)
        level = logger.level
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore even if ``fn`` raised, otherwise a failing test would
            # leave the logger modified for every test that follows.
            logger.handlers = handlers
            logger.setLevel(level)

    return wrapper
|
||
|
||
|
||
def magic_argv_env_context(fn=None, timeout=600):
    """
    Wrap a single test function so that DDP tests run correctly.

    While ``fn`` runs, ``sys.argv`` is replaced by
    ``[argv[0], "<abs test file>::[<Class>::]<fn name><param id>"]`` followed by
    every original argument that does not contain the substring ``"test"``.
    ``<param id>`` is the ``[...]`` part of ``PYTEST_CURRENT_TEST`` when present.
    Afterwards ``sys.argv`` and the environment variables are restored, also
    when ``fn`` raises or times out.

    :param fn: the test function to wrap; ``None`` when the decorator is used
        with arguments, i.e. ``@magic_argv_env_context(timeout=600)``.
    :param timeout: number of seconds after which a still-running test is
        aborted through ``SIGALRM``; defaults to 600 (10 minutes).
    :return: the wrapped function, or a decorator awaiting ``fn`` when ``fn`` is ``None``.
    :raises TimeoutError: raised inside the test when ``timeout`` seconds elapse.
    """
    # Called as @magic_argv_env_context(timeout=...): return the real decorator.
    if fn is None:
        return partial(magic_argv_env_context, timeout=timeout)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        saved_argv = deepcopy(sys.argv)
        saved_env = os.environ.copy()

        used_args = [each_arg for each_arg in sys.argv[1:] if "test" not in each_arg]

        # Extract the parametrization id ("[...]") of the running pytest item, if any.
        subtest = ""
        pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')
        if pytest_current_test is not None:
            try:
                l_index = pytest_current_test.index("[")
                r_index = pytest_current_test.index("]")
            except ValueError:
                # No (complete) bracket pair: the test is not parametrized.
                pass
            else:
                subtest = pytest_current_test[l_index: r_index + 1]

        test_file = os.path.abspath(sys.modules[fn.__module__].__file__)
        owner = get_class_that_defined_method(fn)
        if owner is None and not ismethod(fn):
            target = f"{test_file}::{fn.__name__}{subtest}"
        else:
            target = f"{test_file}::{owner.__name__}::{fn.__name__}{subtest}"

        def _handle_timeout(signum, frame):
            raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n")

        previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
        try:
            sys.argv = [sys.argv[0], target] + used_args
            signal.alarm(timeout)
            return fn(*args, **kwargs)
        finally:
            # Always disarm the alarm and undo the global changes, otherwise a
            # failing test would leak its argv/env (and a pending alarm) into
            # the tests that run after it.
            signal.alarm(0)
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
            sys.argv = saved_argv
            # Restore in place: rebinding ``os.environ`` to a plain dict would
            # detach it from the real process environment.
            os.environ.clear()
            os.environ.update(saved_env)

    return wrapper
|
||
|
||
|
||
class Capturing(list):
    """
    Context manager that captures everything written to stdout and stderr.

    On exit, one string is appended to this list: the captured stdout text
    followed by the captured stderr text.

    Example::

        with Capturing() as output:
            do_something

        assert 'xxx' in output[0]
    """

    def __init__(self, no_del=False):
        # When no_del is True the internal buffers ``_stringio`` and
        # ``_stringioerr`` are kept on the instance after the block exits.
        super().__init__()
        self.no_del = no_del

    def __enter__(self):
        # Remember the real streams so that __exit__ can put them back.
        self._stdout, self._stderr = sys.stdout, sys.stderr
        out_buffer = StringIO()
        sys.stdout = out_buffer
        self._stringio = out_buffer
        err_buffer = StringIO()
        sys.stderr = err_buffer
        self._stringioerr = err_buffer
        return self

    def __exit__(self, *args):
        captured_out = self._stringio.getvalue()
        captured_err = self._stringioerr.getvalue()
        self.append(captured_out + captured_err)
        if not self.no_del:
            # Drop the buffers to free up some memory.
            del self._stringio, self._stringioerr
        sys.stdout = self._stdout
        sys.stderr = self._stderr
|
||
|
||
|
||
def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'):
    """
    Re-launch the current command ``num_procs`` more times, once per extra rank.

    Nothing happens unless ``LOCAL_RANK`` is unset or ``0``. Otherwise the
    command that started this process is rebuilt (using the same interpreter)
    and handed to ``distributed_open_proc`` for ranks ``1 .. num_procs``, each
    with its own copy of the environment carrying the rank variables. A random
    pause of 1-5 seconds follows each launch.

    :param num_procs: number of additional processes to start; the world size
        written to the environment is ``num_procs + 1``.
    :param output_from_new_proc: forwarded unchanged as the first argument of
        ``distributed_open_proc``.
    :return: None
    """
    if int(os.environ.get('LOCAL_RANK', '0')) != 0:
        return

    if __main__.__spec__ is None:  # pragma: no-cover
        # Script called as `python a/b/c.py`: resolve the script to an absolute
        # path and run it with the same interpreter. Build a new list so the
        # live ``sys.argv`` is left untouched.
        command = [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
    else:
        # Script called as `python -m a.b.c`. ``name`` is the public ModuleSpec
        # attribute holding the module name.
        command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

    world_size = f'{num_procs + 1}'
    for rank in range(1, num_procs + 1):
        env_copy = os.environ.copy()
        env_copy["LOCAL_RANK"] = f"{rank}"
        env_copy['WORLD_SIZE'] = world_size
        # Legacy misspelled key, kept in case anything still reads it.
        env_copy['WOLRD_SIZE'] = world_size
        env_copy['RANK'] = f'{rank}'

        # Multi-node runs must be launched by the user, so for the processes
        # started here FASTNLP_GLOBAL_RANK is always equal to LOCAL_RANK.
        env_copy[FASTNLP_GLOBAL_RANK] = str(rank)

        distributed_open_proc(output_from_new_proc, command, env_copy, None)

        # Random pause between launches.
        delay = np.random.uniform(1, 5, 1)[0]
        time.sleep(delay)
|