mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 19:28:17 +08:00
增加pytest.mark.torch or paddle标记
This commit is contained in:
parent
6b7b062cec
commit
0510f31a5e
@ -15,6 +15,8 @@ def test_get_element_shape_dtype():
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.paddle
|
||||
def test_get_padder_run(backend):
|
||||
if not _NEED_IMPORT_TORCH and backend == 'torch':
|
||||
pytest.skip("No torch")
|
||||
@ -100,6 +102,7 @@ def test_numpy_padder():
|
||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_torch_padder():
|
||||
if not _NEED_IMPORT_TORCH:
|
||||
pytest.skip("No torch.")
|
||||
|
@ -14,6 +14,7 @@ class TestNumpyNumberPadder:
|
||||
assert (padder(a) == np.array(a)).sum() == 3
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestNumpySequencePadder:
|
||||
def test_run(self):
|
||||
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
|
||||
|
@ -9,6 +9,7 @@ if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestTorchNumberPadder:
|
||||
def test_run(self):
|
||||
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
|
||||
@ -18,6 +19,7 @@ class TestTorchNumberPadder:
|
||||
assert (t_a == torch.LongTensor(a)).sum() == 3
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestTorchSequencePadder:
|
||||
def test_run(self):
|
||||
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
|
||||
@ -40,7 +42,7 @@ class TestTorchSequencePadder:
|
||||
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestTorchTensorPadder:
|
||||
def test_run(self):
|
||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
|
||||
|
@ -45,6 +45,7 @@ def test_get_padded_nest_list():
|
||||
assert np.shape(a) == (2, 3, 2)
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_is_number_or_numpy_number():
|
||||
assert is_number_or_numpy_number(type(3)) is True
|
||||
assert is_number_or_numpy_number(type(3.1)) is True
|
||||
@ -60,6 +61,7 @@ def test_is_number_or_numpy_number():
|
||||
assert is_number_or_numpy_number(dtype) is False
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_is_number():
|
||||
assert is_number(type(3)) is True
|
||||
assert is_number(type(3.1)) is True
|
||||
@ -75,6 +77,7 @@ def test_is_number():
|
||||
assert is_number(dtype) is False
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_is_numpy_number():
|
||||
assert is_numpy_number_dtype(type(3)) is False
|
||||
assert is_numpy_number_dtype(type(3.1)) is False
|
||||
|
@ -42,6 +42,8 @@ def findListDiff(d1, d2):
|
||||
|
||||
|
||||
class TestCollator:
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_run(self):
|
||||
dict_batch = [{
|
||||
'str': '1',
|
||||
|
@ -17,6 +17,7 @@ class RandomDataset(Dataset):
|
||||
return 10
|
||||
|
||||
|
||||
@pytest.mark.paddle
|
||||
class TestPaddle:
|
||||
|
||||
def test_init(self):
|
||||
|
@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestFdl:
|
||||
|
||||
def test_init_v1(self):
|
||||
|
@ -69,6 +69,7 @@ def pre_process():
|
||||
pool.join()
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize('dataset', [
|
||||
DataSet({'pred': np.random.randint(low=0, high=1, size=(36, 32)),
|
||||
'target': np.random.randint(low=0, high=1, size=(36, 32))}),
|
||||
|
@ -8,11 +8,13 @@ import paddle.distributed.fleet as fleet
|
||||
from fastNLP.core.metrics import Accuracy
|
||||
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
|
||||
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试 单机单卡情况下的Accuracy
|
||||
#
|
||||
############################################################################
|
||||
@pytest.mark.paddle
|
||||
def test_accuracy_single():
|
||||
pred = paddle.to_tensor([[1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669,
|
||||
0.46514302],
|
||||
@ -56,4 +58,3 @@ def test_accuracy_ddp():
|
||||
pass
|
||||
elif fleet.is_worker():
|
||||
print(os.getenv("PADDLE_TRAINER_ID"))
|
||||
|
||||
|
@ -29,6 +29,7 @@ def _test(local_rank: int, world_size: int, device: torch.device,
|
||||
np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestClassfiyFPreRecMetric:
|
||||
def test_case_1(self):
|
||||
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],
|
||||
|
@ -66,6 +66,7 @@ def _test(local_rank: int,
|
||||
assert my_result == sklearn_metric
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
class TestSpanFPreRecMetric:
|
||||
|
||||
def test_case1(self):
|
||||
|
Loading…
Reference in New Issue
Block a user