Limit the devices used in tests; add a version restriction for the fsdp tests

x54-729 2022-09-19 16:31:25 +08:00
parent 5ebc8d1be4
commit 82fe167ea3
4 changed files with 22 additions and 15 deletions

View File

@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
:param seq_len: the length of each sentence, of shape ``[batch,]``
:return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
"""
- return self(words, seq_len, target)
+ return self(words, target, seq_len)
def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
"""
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
:param seq_len: the length of each sentence, of shape ``[batch,]``
:return: the prediction ``{'pred': torch.Tensor}``
"""
- return self(words, seq_len)
+ return self(words, seq_len=seq_len)
class SeqLabeling(nn.Module):
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
:param seq_len: the length of each sentence, of shape ``[batch,]``
:return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
"""
- return self(words, seq_len, target)
+ return self(words, target, seq_len)
def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
"""
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
:param seq_len: the length of each sentence, of shape ``[batch,]``
:return: the prediction ``{'pred': torch.Tensor}``
"""
- return self(words, seq_len)
+ return self(words, seq_len=seq_len)
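
The first file only reorders the arguments that ``train_step`` / ``evaluate_step`` pass on to ``self(...)`` so that they line up with ``forward``'s parameter order. A minimal sketch of the pattern, assuming a ``forward(self, words, target=None, seq_len=None)`` signature inferred from the new call order (the ``TinySeqLabel`` module below is illustrative, not the fastNLP model):

from torch import nn

class TinySeqLabel(nn.Module):
    """Toy stand-in for a sequence-labeling model (not the fastNLP class)."""
    def __init__(self, vocab_size=10, num_labels=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 8)
        self.fc = nn.Linear(8, num_labels)

    def forward(self, words, target=None, seq_len=None):
        logits = self.fc(self.embed(words))            # [batch, seq_len, num_labels]
        if target is None:
            return {'pred': logits.argmax(dim=-1)}
        loss = nn.functional.cross_entropy(logits.transpose(1, 2), target)
        return {'loss': loss}

    def train_step(self, words, target, seq_len):
        # positional order must match forward: (words, target, seq_len)
        return self(words, target, seq_len)

    def evaluate_step(self, words, seq_len):
        # the keyword form sidesteps the ordering problem entirely
        return self(words, seq_len=seq_len)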

View File

@@ -103,8 +103,8 @@ def model_and_optimizers(request):
# test the normal case
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
- ("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+ @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+ ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_1(
model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_2(
model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
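
This file's changes only swap the hard-coded GPU indices 4/5 for 0/1 so the suite runs on a standard two-GPU machine. A complementary guard, purely illustrative and not part of this commit (``_min_gpu_count`` and ``test_device_guard`` are hypothetical names), would be to skip any device case the current machine cannot satisfy:

import pytest
import torch

def _min_gpu_count(device):
    """Smallest torch.cuda.device_count() that can satisfy a device spec."""
    if device == 'cpu':
        return 0
    if isinstance(device, (list, tuple)):
        return max(device) + 1
    return device + 1  # a bare index such as 1 means cuda:1

@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])])
def test_device_guard(driver, device):
    if torch.cuda.device_count() < _min_gpu_count(device):
        pytest.skip(f"needs {_min_gpu_count(device)} visible GPU(s)")
    ...  # run the actual trainer test here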

View File

@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.distributed import rank_zero_rm
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
@@ -290,7 +290,7 @@ def test_trainer_on_exception(
@pytest.mark.torch
- @pytest.mark.parametrize("version", [0, 1, 2, 3])
+ @pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_torch_distributed_launch_1(version):
"""
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):
@pytest.mark.torch
- @pytest.mark.parametrize("version", [0, 1, 2, 3])
+ @pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_torch_distributed_launch_2(version):
"""
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
device,
n_epochs=2,
):
+ if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+     pytest.skip("fsdp requires torch 1.12 or above")
model = TorchNormalModel_Classification_3(
num_labels=NormalClassificationTrainTorchConfig.num_labels,
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
overfit_batches,
num_train_batch_per_epoch
):
+ if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+     pytest.skip("fsdp requires torch 1.12 or above")
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
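
``_TORCH_GREATER_EQUAL_1_12`` is imported from ``fastNLP.envs.imports``; a common way to define such a flag is a one-line version comparison (a sketch of the idea only, the actual definition inside fastNLP may differ):

import torch
from packaging.version import parse as parse_version

# True on torch 1.12.0 and later, including local builds such as "1.12.1+cu113".
_TORCH_GREATER_EQUAL_1_12 = parse_version(torch.__version__) >= parse_version("1.12")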

View File

@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.utils import magic_argv_env_context
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):
return trainer_params
+ @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_torch_without_evaluator(
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
trainer = Trainer(
model=model_and_optimizers.model,
driver="torch_fsdp",
- device=[4, 5],
+ device=[0, 1],
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
dist.destroy_process_group()
+ @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
@pytest.mark.torch
@pytest.mark.parametrize("save_on_rank0", [True, False])
@magic_argv_env_context(timeout=100)
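
Unlike the previous file, where only some parametrized ``driver`` values hit fsdp and the gate therefore sits inside the test body, every test in this file is fsdp-only, so the gate can be a ``skipif`` marker and the test is skipped at collection time. A sketch of that form (the test name is hypothetical); note that pytest expects ``reason`` as a keyword argument, otherwise the string is treated as another condition to evaluate:

import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                    reason="fsdp requires torch 1.12 or above")  # reason must be a keyword
@pytest.mark.torch
def test_fsdp_smoke():  # hypothetical fsdp-only test
    ...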