fastNLP/tests/helpers/utils.py

124 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import __main__
from functools import wraps
import inspect
from inspect import ismethod
import functools
from copy import deepcopy
from io import StringIO
import time
import numpy as np
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
from fastNLP.core.drivers.utils import distributed_open_proc
def get_class_that_defined_method(meth):
if isinstance(meth, functools.partial):
return get_class_that_defined_method(meth.func)
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
for cls in inspect.getmro(meth.__self__.__class__):
if meth.__name__ in cls.__dict__:
return cls
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing
if inspect.isfunction(meth):
cls = getattr(inspect.getmodule(meth),
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
None)
if isinstance(cls, type):
return cls
return getattr(meth, '__objclass__', None) # handle special descriptor objects
def magic_argv_env_context(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
command = deepcopy(sys.argv)
env = deepcopy(os.environ.copy())
used_args = []
for each_arg in sys.argv[1:]:
if "test" not in each_arg:
used_args.append(each_arg)
pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST')
try:
l_index = pytest_current_test.index("[")
r_index = pytest_current_test.index("]")
subtest = pytest_current_test[l_index: r_index + 1]
except:
subtest = ""
if not ismethod(fn) and get_class_that_defined_method(fn) is None:
sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{fn.__name__}{subtest}"] + used_args
else:
sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args
res = fn(*args, **kwargs)
sys.argv = deepcopy(command)
os.environ = env
return res
return wrapper
class Capturing(list):
# 用来捕获当前环境中的stdout和stderr会将其中stderr的输出拼接在stdout的输出后面
"""
使用例子
with Capturing() as output:
do_something
assert 'xxx' in output[0]
"""
def __init__(self, no_del=False):
# 如果no_del为True则不会删除_stringio和_stringioerr
super().__init__()
self.no_del = no_del
def __enter__(self):
self._stdout = sys.stdout
self._stderr = sys.stderr
sys.stdout = self._stringio = StringIO()
sys.stderr = self._stringioerr = StringIO()
return self
def __exit__(self, *args):
self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
if not self.no_del:
del self._stringio, self._stringioerr # free up some memory
sys.stdout = self._stdout
sys.stderr = self._stderr
def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'):
# Script called as `python a/b/c.py`
if int(os.environ.get('LOCAL_RANK', '0')) == 0:
if __main__.__spec__ is None: # pragma: no-cover
# pull out the commands used to run the script and resolve the abs file path
command = sys.argv
command[0] = os.path.abspath(command[0])
# use the same python interpreter and actually running
command = [sys.executable] + command
# Script called as `python -m a.b.c`
else:
command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:]
for rank in range(1, num_procs+1):
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{rank}"
env_copy['WOLRD_SIZE'] = f'{num_procs+1}'
env_copy['RANK'] = f'{rank}'
# 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK
env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
proc = distributed_open_proc(output_from_new_proc, command, env_copy, None)
delay = np.random.uniform(1, 5, 1)[0]
time.sleep(delay)