mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 21:59:30 +08:00
425 lines
12 KiB
Python
425 lines
12 KiB
Python
import os
|
||
import pytest
|
||
import subprocess
|
||
from io import StringIO
|
||
import sys
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
|
||
|
||
from fastNLP.core.utils.cache_results import cache_results
|
||
from fastNLP.envs.distributed import rank_zero_rm
|
||
|
||
|
||
def get_subprocess_results(cmd):
|
||
output = subprocess.check_output(cmd, shell=True)
|
||
return output.decode('utf8')
|
||
|
||
|
||
class Capturing(list):
|
||
# 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
|
||
def __enter__(self):
|
||
self._stdout = sys.stdout
|
||
self._stderr = sys.stderr
|
||
sys.stdout = self._stringio = StringIO()
|
||
sys.stderr = self._stringioerr = StringIO()
|
||
return self
|
||
|
||
def __exit__(self, *args):
|
||
self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
|
||
del self._stringio, self._stringioerr # free up some memory
|
||
sys.stdout = self._stdout
|
||
sys.stderr = self._stderr
|
||
|
||
|
||
class TestCacheResults:
|
||
def test_cache_save(self):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp)
|
||
def demo():
|
||
print("¥")
|
||
return 1
|
||
res = demo()
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert '¥' not in output[0]
|
||
|
||
finally:
|
||
rank_zero_rm(cache_fp)
|
||
|
||
def test_cache_save_refresh(self):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp, _refresh=True)
|
||
def demo():
|
||
print("¥")
|
||
return 1
|
||
|
||
res = demo()
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert '¥' in output[0]
|
||
finally:
|
||
rank_zero_rm(cache_fp)
|
||
|
||
def test_cache_no_func_change(self):
|
||
cache_fp = os.path.abspath('demo.pkl')
|
||
try:
|
||
@cache_results(cache_fp)
|
||
def demo():
|
||
print('¥')
|
||
return 1
|
||
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert '¥' in output[0]
|
||
|
||
@cache_results(cache_fp)
|
||
def demo():
|
||
print('¥')
|
||
return 1
|
||
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert '¥' not in output[0]
|
||
finally:
|
||
rank_zero_rm('demo.pkl')
|
||
|
||
def test_cache_func_change(self, capsys):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp)
|
||
def demo():
|
||
print('¥')
|
||
return 1
|
||
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert '¥' in output[0]
|
||
|
||
@cache_results(cache_fp)
|
||
def demo():
|
||
print('¥¥')
|
||
return 1
|
||
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert 'different' in output[0]
|
||
assert '¥' not in output[0]
|
||
|
||
# 关闭check_hash应该不warning的
|
||
with Capturing() as output:
|
||
res = demo(_check_hash=0)
|
||
assert 'different' not in output[0]
|
||
assert '¥' not in output[0]
|
||
|
||
finally:
|
||
rank_zero_rm('demo.pkl')
|
||
|
||
def test_cache_check_hash(self):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp, _check_hash=False)
|
||
def demo():
|
||
print('¥')
|
||
return 1
|
||
|
||
with Capturing() as output:
|
||
res = demo(_check_hash=0)
|
||
assert '¥' in output[0]
|
||
|
||
@cache_results(cache_fp, _check_hash=False)
|
||
def demo():
|
||
print('¥¥')
|
||
return 1
|
||
|
||
# 默认不会check
|
||
with Capturing() as output:
|
||
res = demo()
|
||
assert 'different' not in output[0]
|
||
assert '¥' not in output[0]
|
||
|
||
# check也可以
|
||
with Capturing() as output:
|
||
res = demo(_check_hash=True)
|
||
assert 'different' in output[0]
|
||
assert '¥' not in output[0]
|
||
|
||
finally:
|
||
rank_zero_rm('demo.pkl')
|
||
|
||
# 外部 function 改变也会 导致改变
|
||
def test_refer_fun_change(self):
|
||
cache_fp = 'demo.pkl'
|
||
test_type = 'func_refer_fun_change'
|
||
try:
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" in res
|
||
# 引用的function没有变化
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" not in res
|
||
|
||
assert 'Read' in res
|
||
assert 'different' not in res
|
||
|
||
# 引用的function有变化
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" not in res
|
||
assert 'different' in res
|
||
|
||
finally:
|
||
rank_zero_rm(cache_fp)
|
||
|
||
# 外部 method 改变也会 导致改变
|
||
def test_refer_class_method_change(self):
|
||
cache_fp = 'demo.pkl'
|
||
test_type = 'refer_class_method_change'
|
||
try:
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" in res
|
||
|
||
# 引用的class没有变化
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert 'Read' in res
|
||
assert 'different' not in res
|
||
assert "¥" not in res
|
||
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
|
||
res = get_subprocess_results(cmd)
|
||
assert 'different' in res
|
||
assert "¥" not in res
|
||
|
||
finally:
|
||
rank_zero_rm(cache_fp)
|
||
|
||
def test_duplicate_keyword(self):
|
||
with pytest.raises(RuntimeError):
|
||
@cache_results(None)
|
||
def func_verbose(a, _verbose):
|
||
pass
|
||
|
||
func_verbose(0, 1)
|
||
with pytest.raises(RuntimeError):
|
||
@cache_results(None)
|
||
def func_cache(a, _cache_fp):
|
||
pass
|
||
|
||
func_cache(1, 2)
|
||
with pytest.raises(RuntimeError):
|
||
@cache_results(None)
|
||
def func_refresh(a, _refresh):
|
||
pass
|
||
|
||
func_refresh(1, 2)
|
||
|
||
with pytest.raises(RuntimeError):
|
||
@cache_results(None)
|
||
def func_refresh(a, _check_hash):
|
||
pass
|
||
|
||
func_refresh(1, 2)
|
||
|
||
def test_create_cache_dir(self):
|
||
@cache_results('demo/demo.pkl')
|
||
def cache():
|
||
return 1, 2
|
||
|
||
try:
|
||
results = cache()
|
||
assert (1, 2) == results
|
||
finally:
|
||
rank_zero_rm('demo/')
|
||
|
||
def test_result_none_error(self):
|
||
@cache_results('demo.pkl')
|
||
def cache():
|
||
pass
|
||
|
||
try:
|
||
with pytest.raises(RuntimeError):
|
||
results = cache()
|
||
finally:
|
||
rank_zero_rm('demo.pkl')
|
||
|
||
|
||
def remove_postfix(folder='.', post_fix='.pkl'):
|
||
import os
|
||
for f in os.listdir(folder):
|
||
if os.path.isfile(f) and f.endswith(post_fix):
|
||
os.remove(os.path.join(folder, f))
|
||
|
||
|
||
class TestCacheResultsWithParam:
|
||
@pytest.mark.parametrize('_refresh', [True, False])
|
||
@pytest.mark.parametrize('_hash_param', [True, False])
|
||
@pytest.mark.parametrize('_verbose', [0, 1])
|
||
@pytest.mark.parametrize('_check_hash', [True, False])
|
||
def test_cache_save(self, _refresh, _hash_param, _verbose, _check_hash):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp, _refresh=_refresh, _hash_param=_hash_param, _verbose=_verbose,
|
||
_check_hash=_check_hash)
|
||
def demo(a=1):
|
||
print("¥")
|
||
return 1
|
||
res = demo()
|
||
|
||
with Capturing() as output:
|
||
res = demo(a=1)
|
||
if _refresh is False:
|
||
assert '¥' not in output[0]
|
||
if _verbose is 0:
|
||
assert 'read' not in output[0]
|
||
|
||
with Capturing() as output:
|
||
res = demo(1)
|
||
if _refresh is False:
|
||
assert '¥' not in output[0]
|
||
|
||
with Capturing() as output:
|
||
res = demo(a=2)
|
||
if _hash_param is True: # 一定对不上,需要重新生成
|
||
assert '¥' in output[0]
|
||
|
||
finally:
|
||
remove_postfix('.')
|
||
|
||
def test_cache_complex_param(self):
|
||
cache_fp = 'demo.pkl'
|
||
try:
|
||
@cache_results(cache_fp, _refresh=False)
|
||
def demo(*args, s=1, **kwargs):
|
||
print("¥")
|
||
return 1
|
||
|
||
res = demo(1,2,3, s=4, d=4)
|
||
with Capturing() as output:
|
||
res = demo(1,2,3,d=4, s=4)
|
||
assert '¥' not in output[0]
|
||
finally:
|
||
remove_postfix('.')
|
||
|
||
def test_wrapper_change(self):
|
||
cache_fp = 'demo.pkl'
|
||
test_type = 'wrapper_change'
|
||
try:
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" in res
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" not in res
|
||
assert 'Read' in res
|
||
assert 'different' not in res
|
||
|
||
finally:
|
||
remove_postfix('.')
|
||
|
||
def test_param_change(self):
|
||
cache_fp = 'demo.pkl'
|
||
test_type = 'param_change'
|
||
try:
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" in res
|
||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
|
||
res = get_subprocess_results(cmd)
|
||
assert "¥" in res
|
||
assert 'Read' not in res
|
||
finally:
|
||
remove_postfix('.')
|
||
|
||
def test_create_cache_dir(self):
|
||
@cache_results('demo/demo.pkl')
|
||
def cache(s):
|
||
return 1, 2
|
||
|
||
try:
|
||
results = cache(s=1)
|
||
assert (1, 2) == results
|
||
finally:
|
||
import shutil
|
||
shutil.rmtree('demo/')
|
||
|
||
|
||
if __name__ == '__main__':
|
||
import argparse
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument('--test_type', type=str, default='refer_class_method_change')
|
||
parser.add_argument('--turn', type=int, default=1)
|
||
parser.add_argument('--cache_fp', type=str, default='demo.pkl')
|
||
args = parser.parse_args()
|
||
|
||
test_type = args.test_type
|
||
cache_fp = args.cache_fp
|
||
turn = args.turn
|
||
|
||
if test_type == 'func_refer_fun_change':
|
||
if turn == 0:
|
||
def demo():
|
||
b = 1
|
||
return b
|
||
else:
|
||
def demo():
|
||
b = 2
|
||
return b
|
||
|
||
@cache_results(cache_fp)
|
||
def demo_refer_other_func():
|
||
b = demo()
|
||
print("¥")
|
||
return b
|
||
|
||
res = demo_refer_other_func()
|
||
|
||
if test_type == 'refer_class_method_change':
|
||
print(f"Turn:{turn}")
|
||
if turn == 0:
|
||
from helper_for_cache_results_1 import Demo
|
||
else:
|
||
from helper_for_cache_results_2 import Demo
|
||
|
||
demo = Demo()
|
||
# import pdb
|
||
# pdb.set_trace()
|
||
@cache_results(cache_fp)
|
||
def demo_func():
|
||
print("¥")
|
||
b = demo.demo()
|
||
return b
|
||
|
||
res = demo_func()
|
||
|
||
if test_type == 'wrapper_change':
|
||
if turn == 0:
|
||
@cache_results(cache_fp, _refresh=True)
|
||
def demo_wrapper_change():
|
||
print("¥")
|
||
return 1
|
||
else:
|
||
@cache_results(cache_fp, _refresh=False)
|
||
def demo_wrapper_change():
|
||
print("¥")
|
||
return 1
|
||
|
||
res = demo_wrapper_change()
|
||
|
||
if test_type == 'param_change':
|
||
if turn == 0:
|
||
@cache_results(cache_fp, _refresh=False)
|
||
def demo_param_change():
|
||
print("¥")
|
||
return 1
|
||
else:
|
||
@cache_results(cache_fp, _refresh=False)
|
||
def demo_param_change(a=1):
|
||
print("¥")
|
||
return 1
|
||
|
||
res = demo_param_change()
|
||
|