Add pytest.mark.torch / pytest.mark.paddle markers

MorningForest 2022-05-01 17:19:42 +08:00
parent 6b7b062cec
commit 0510f31a5e
11 changed files with 19 additions and 2 deletions
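
For context: pytest treats torch and paddle as custom marks, so they must be registered before "pytest -m" can select or deselect the backend-specific tests this commit tags. Below is a minimal registration sketch for a conftest.py, assuming the marks are not already declared elsewhere; fastNLP's actual test configuration may differ.

# conftest.py -- hypothetical registration sketch; fastNLP's real setup may differ.
# Registering the marks silences pytest's "unknown mark" warnings and enables
# command-line selection, e.g.:
#   pytest -m torch                        # run only torch-dependent tests
#   pytest -m "not torch and not paddle"   # skip all backend-specific tests

def pytest_configure(config):
    config.addinivalue_line("markers", "torch: test requires PyTorch")
    config.addinivalue_line("markers", "paddle: test requires PaddlePaddle")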

View File

@@ -15,6 +15,8 @@ def test_get_element_shape_dtype():
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
@pytest.mark.torch
@pytest.mark.paddle
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")
@@ -100,6 +102,7 @@ def test_numpy_padder():
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
@pytest.mark.torch
def test_torch_padder():
if not _NEED_IMPORT_TORCH:
pytest.skip("No torch.")
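
Throughout the commit, each new mark is paired with a runtime availability check, so a test can be deselected with -m and is still skipped cleanly when the backend is not installed. The following is a self-contained sketch of that pattern; the availability flag is recomputed here with importlib as a stand-in for the flag fastNLP imports from its own environment helpers.

import importlib.util

import pytest

# Stand-in for fastNLP's _NEED_IMPORT_TORCH flag (assumption: the real flag
# likewise reflects whether torch can be imported).
_NEED_IMPORT_TORCH = importlib.util.find_spec("torch") is not None


@pytest.mark.torch
def test_torch_available_smoke():
    if not _NEED_IMPORT_TORCH:
        pytest.skip("No torch.")
    import torch  # imported lazily so collection succeeds without torch
    assert torch.zeros(3).sum().item() == 0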

View File

@@ -14,6 +14,7 @@ class TestNumpyNumberPadder:
assert (padder(a) == np.array(a)).sum() == 3
@pytest.mark.torch
class TestNumpySequencePadder:
def test_run(self):
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)

View File

@@ -9,6 +9,7 @@ if _NEED_IMPORT_TORCH:
import torch
@pytest.mark.torch
class TestTorchNumberPadder:
def test_run(self):
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
@@ -18,6 +19,7 @@ class TestTorchNumberPadder:
assert (t_a == torch.LongTensor(a)).sum() == 3
@pytest.mark.torch
class TestTorchSequencePadder:
def test_run(self):
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
@@ -40,7 +42,7 @@ class TestTorchSequencePadder:
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
@pytest.mark.torch
class TestTorchTensorPadder:
def test_run(self):
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)

View File

@@ -45,6 +45,7 @@ def test_get_padded_nest_list():
assert np.shape(a) == (2, 3, 2)
@pytest.mark.torch
def test_is_number_or_numpy_number():
assert is_number_or_numpy_number(type(3)) is True
assert is_number_or_numpy_number(type(3.1)) is True
@@ -60,6 +61,7 @@ def test_is_number_or_numpy_number():
assert is_number_or_numpy_number(dtype) is False
@pytest.mark.torch
def test_is_number():
assert is_number(type(3)) is True
assert is_number(type(3.1)) is True
@@ -75,6 +77,7 @@ def test_is_number():
assert is_number(dtype) is False
@pytest.mark.torch
def test_is_numpy_number():
assert is_numpy_number_dtype(type(3)) is False
assert is_numpy_number_dtype(type(3.1)) is False

View File

@@ -42,6 +42,8 @@ def findListDiff(d1, d2):
class TestCollator:
@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',

View File

@@ -17,6 +17,7 @@ class RandomDataset(Dataset):
return 10
@pytest.mark.paddle
class TestPaddle:
def test_init(self):

View File

@@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle
@pytest.mark.torch
class TestFdl:
def test_init_v1(self):

View File

@@ -69,6 +69,7 @@ def pre_process():
pool.join()
@pytest.mark.torch
@pytest.mark.parametrize('dataset', [
DataSet({'pred': np.random.randint(low=0, high=1, size=(36, 32)),
'target': np.random.randint(low=0, high=1, size=(36, 32))}),

View File

@@ -8,11 +8,13 @@ import paddle.distributed.fleet as fleet
from fastNLP.core.metrics import Accuracy
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
############################################################################
#
# Test Accuracy in the single-machine, single-GPU case
#
############################################################################
@pytest.mark.paddle
def test_accuracy_single():
pred = paddle.to_tensor([[1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669,
0.46514302],
@@ -56,4 +58,3 @@ def test_accuracy_ddp():
pass
elif fleet.is_worker():
print(os.getenv("PADDLE_TRAINER_ID"))

View File

@@ -29,6 +29,7 @@ def _test(local_rank: int, world_size: int, device: torch.device,
np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
@pytest.mark.torch
class TestClassfiyFPreRecMetric:
def test_case_1(self):
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],

View File

@@ -66,6 +66,7 @@ def _test(local_rank: int,
assert my_result == sklearn_metric
@pytest.mark.torch
class TestSpanFPreRecMetric:
def test_case1(self):