Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-01 19:57:34 +08:00)
Restrict the devices used in tests; add a torch version requirement to the fsdp tests
commit 82fe167ea3 · parent 5ebc8d1be4
@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, with shape ``[batch,]``
         :return: if ``target`` is ``None``, the prediction ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, with shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
 
 
 class SeqLabeling(nn.Module):
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, with shape ``[batch,]``
         :return: if ``target`` is ``None``, the prediction ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, with shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
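The four call-site fixes above all reorder or name the arguments so that they line up with the models' ``forward`` signature. A minimal sketch of the assumed signature (the parameter order is inferred from the fix itself; the names and defaults are an assumption, not shown in the diff):

```python
import torch
from torch import nn

class SketchTagger(nn.Module):
    """Hypothetical stand-in with the assumed parameter order."""

    def forward(self, words, target=None, seq_len=None):
        # Under this order, the old call self(words, seq_len, target)
        # bound the lengths to `target` and the labels to `seq_len`;
        # self(words, target, seq_len) and self(words, seq_len=seq_len)
        # both line up correctly.
        if target is None:
            return {'pred': torch.zeros(words.shape[0], dtype=torch.long)}
        return {'loss': torch.zeros(())}
```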
@@ -103,8 +103,8 @@ def model_and_optimizers(request):
 
 # Test the ordinary case
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
-                                           ("torch", [4, 5])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+                                           ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
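Every device parametrization below changes the same way: CUDA ordinals 4 and 5 assume a host with at least six GPUs, so the tests are pinned to ordinals 0/1 (and single-device cases to device 1), which any two-GPU machine satisfies. For illustration, a hedged sketch of a guard one could layer on top of such parametrizations (`require_cuda_devices` is a hypothetical helper, not part of this commit or of fastNLP):

```python
import pytest
import torch

def require_cuda_devices(device):
    """Skip when the `device` spec references GPUs the host lacks.

    `device` follows the trainer convention used in these tests:
    "cpu", an int ordinal, or a list of ordinals. Hypothetical helper.
    """
    if device == "cpu":
        return
    ordinals = device if isinstance(device, list) else [device]
    needed = max(ordinals) + 1
    if not torch.cuda.is_available() or torch.cuda.device_count() < needed:
        pytest.skip(f"test needs {needed} CUDA device(s)")
```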
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("fp16", [True, False])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_1(
         model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_2(
         model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(

@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
 from tests.helpers.utils import magic_argv_env_context, Capturing
 from fastNLP.envs.distributed import rank_zero_rm
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.optim import SGD
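For context, gates such as ``_TORCH_GREATER_EQUAL_1_12`` are conventionally computed once at import time. A sketch of how such a flag is typically defined (an assumption for illustration; the real definition lives in ``fastNLP.envs.imports`` and may differ):

```python
from packaging.version import parse as parse_version

try:
    import torch
    # Strip any local suffix such as "+cu113" before comparing.
    _TORCH_GREATER_EQUAL_1_12 = (
        parse_version(torch.__version__.split('+')[0]) >= parse_version('1.12')
    )
except ImportError:
    _TORCH_GREATER_EQUAL_1_12 = False
```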
@@ -290,7 +290,7 @@ def test_trainer_on_exception(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_1(version):
     """
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_2(version):
     """
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
         device,
         n_epochs=2,
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch >= 1.12")
 
     model = TorchNormalModel_Classification_3(
         num_labels=NormalClassificationTrainTorchConfig.num_labels,
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
         overfit_batches,
         num_train_batch_per_epoch
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch >= 1.12")
+
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,

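The two guards above skip at run time rather than via a decorator because the decision depends on the ``driver`` argument, which only exists per parametrization. The same effect can be expressed at the decorator level by marking individual parameter sets; a hedged sketch of that alternative (not part of the commit):

```python
import pytest

_TORCH_GREATER_EQUAL_1_12 = True  # stand-in for the imported flag

# Skip only the fsdp parameter set when torch is too old.
@pytest.mark.parametrize("driver,device", [
    pytest.param("torch", 1),
    pytest.param("torch_fsdp", [0, 1],
                 marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                                          reason="fsdp requires torch >= 1.12")),
])
def test_sketch(driver, device):
    ...
```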
@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.utils import magic_argv_env_context
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):
 
     return trainer_params
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch >= 1.12")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver="torch_fsdp",
-        device=[4, 5],
+        device=[0, 1],
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
         evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
         dist.destroy_process_group()
 
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch >= 1.12")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
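One detail worth noting about the two ``skipif`` markers above: when the condition is a boolean, ``pytest`` requires the explanation to be passed as ``reason=``; a positional string would itself be treated as a second (string) condition and evaluated, so the keyword form is the correct one. Minimal usage sketch:

```python
import pytest

_TORCH_GREATER_EQUAL_1_12 = False  # stand-in for the imported flag

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                    reason="fsdp requires torch >= 1.12")
def test_fsdp_sketch():
    ...
```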