Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-11-29 18:59:01 +08:00)
Initial fix for the error raised when evaluate_dataloader uses a customized batch_sampler in the multi-GPU case; a warning is now given instead.
Commit: 28db704f70
Parent: fe30d02f86
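
Before the diff, a minimal sketch of the scenario this commit addresses. The sampler class and toy dataset below are illustrative placeholders, not part of the commit: any dataloader whose batch_sampler is not exactly torch's BatchSampler is treated as customized, and distributed evaluation now warns instead of raising.

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    class EveryOtherBatchSampler(BatchSampler):
        """Hypothetical user-defined batch sampler: keeps only every other batch."""
        def __iter__(self):
            for i, batch in enumerate(super().__iter__()):
                if i % 2 == 0:
                    yield batch

    dataset = TensorDataset(torch.arange(32).float().unsqueeze(1))
    batch_sampler = EveryOtherBatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)

    # Mirrors the type check added in TorchDDPDriver below: a subclass fails the exact-type
    # comparison, so fastNLP now logs a warning for distributed evaluation instead of erroring.
    print(type(loader.batch_sampler) != BatchSampler)  # True -> treated as customized, warning path
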
@@ -609,9 +609,15 @@ class TorchDDPDriver(TorchDriver):
        # evaluator
        elif dist == "unrepeatdist":
            args = self.get_dataloader_args(dataloader)
            if type(args.batch_sampler) != BatchSampler:
                # TODO: the goal here is to detect a user-customized batch_sampler; this check may need refinement
                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate dataloader or "
                               "the train dataloader while testing ``overfit_batches``, which may cause "
                               "the data for distributed evaluation to be repeated.")
            if isinstance(args.sampler, ReproducibleSampler):
                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
            elif not isinstance(args.sampler, UnrepeatedSampler):
                # TODO: avoid the batch_sampler case
                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                sampler = UnrepeatedSequentialSampler(
                    dataset=args.dataset
@@ -622,6 +628,7 @@ class TorchDDPDriver(TorchDriver):
                    num_replicas=self.world_size,
                    rank=self.global_rank
                )
            # TODO: for now, uniformly replace with BatchSampler here
            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
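
For readers unfamiliar with the "unrepeatdist" mode referenced above, here is an illustrative sketch of the sharding idea behind UnrepeatedSequentialSampler-style samplers (not fastNLP's actual implementation): every sample is assigned to exactly one rank, so distributed evaluation sees no duplicated or padded data.

    # Illustrative only: the core index math of "unrepeated" distributed evaluation.
    def unrepeated_indices(num_samples: int, num_replicas: int, rank: int) -> list:
        # Rank r takes samples r, r + num_replicas, r + 2 * num_replicas, ...
        return list(range(rank, num_samples, num_replicas))

    shards = [unrepeated_indices(10, num_replicas=3, rank=r) for r in range(3)]
    print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
    # Every sample appears on exactly one rank, with no padding or repetition.
    assert sorted(i for shard in shards for i in shard) == list(range(10))
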
@@ -14,7 +14,7 @@ from fastNLP.envs import (
    FASTNLP_BACKEND_LAUNCH,
    FASTNLP_GLOBAL_SEED,
)
-from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
+from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.utils import auto_param_call, apply_to_collection
from fastNLP.core.log import logger

@@ -308,15 +308,26 @@ def optimizer_state_to_device(state, device):


def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
-                                                                                          TorchSequentialSampler}):
-        mode = 'training' if controller == 'Trainer' else 'evaluation'
-        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+    """
+    Check the dataloader's sampler configuration; if the user has substituted a customized
+    sampler, raise an error to prevent failures during distributed training.
+    """
+    error_flag = (type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler})
+    if controller == 'Trainer':
+        mode = 'training'
+        substitution = 'fastNLP.RandomSampler'
+        error_flag = (type(args.batch_sampler) != TorchBatchSampler) or error_flag
+    else:  # Evaluator
+        mode = 'evaluation'
+        substitution = 'fastNLP.UnrepeatedSequentialSampler'
+    if error_flag:
        raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler with "
                        f"``{substitution}``. The customized sampler should be set up for distributed running "
                        f"before initializing ``{controller}``, and then set the "
-                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
+                        f"\n Current sampler: {type(args.sampler)}")


def _create_default_config(
    zero_optimization: bool = True,
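
As a quick illustration of what the stricter Trainer branch of _check_dataloader_args_for_distributed accepts, the sketch below uses plain torch objects; the ReversedSampler class is a hypothetical example, not from fastNLP. A default DataLoader passes the exact-type checks, while one built with a user-defined sampler would hit the TypeError above.

    import torch
    from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                                  Sampler, SequentialSampler, TensorDataset)

    class ReversedSampler(Sampler):
        """Hypothetical customized sampler that the Trainer check would reject."""
        def __init__(self, data_source):
            self.data_source = data_source
        def __iter__(self):
            return iter(range(len(self.data_source) - 1, -1, -1))
        def __len__(self):
            return len(self.data_source)

    ds = TensorDataset(torch.arange(16).float())

    # Default DataLoader: torch wraps it in BatchSampler + RandomSampler -> passes.
    default_dl = DataLoader(ds, batch_size=4, shuffle=True)
    print(type(default_dl.batch_sampler) is BatchSampler)                        # True
    print(type(default_dl.sampler) in {TorchRandomSampler := RandomSampler,
                                       SequentialSampler})                       # True

    # Customized sampler: exact-type membership test fails -> TypeError for Trainer.
    custom_dl = DataLoader(ds, batch_size=4, sampler=ReversedSampler(ds))
    print(type(custom_dl.sampler) in {RandomSampler, SequentialSampler})         # False
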
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Any

from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.samplers import BucketedBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback

@@ -378,4 +379,96 @@ def test_trainer_w_evaluator_overfit_torch(
    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

-    if dist.is_initialized():
-        dist.destroy_process_group()
+    dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("train_sampler", ["batch_sampler", "sampler"])
@pytest.mark.parametrize("eval_sampler", ["batch_sampler", "sampler"])
@pytest.mark.parametrize("overfit_batches", [-1, 0])
@magic_argv_env_context
def test_trainer_w_evaluator_w_samplers(
    driver,
    device,
    train_sampler,
    eval_sampler,
    overfit_batches,
):
    """
    Test the case where the dataloader uses a customized (but valid) batch_sampler or sampler.
    """
    model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    optimizers = SGD(model.parameters(), lr=0.001)
    metrics = {"acc": Accuracy()}

    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    if train_sampler == "batch_sampler":
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=BucketedBatchSampler(
                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
            )
        )
    elif train_sampler == "sampler":
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            sampler=RandomSampler(dataset)
        )
    else:
        train_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            shuffle=True,
        )
    if eval_sampler == "batch_sampler":
        eval_dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=BucketedBatchSampler(
                dataset, [3] * len(dataset), ArgMaxDatasetConfig.batch_size
            )
        )
    elif eval_sampler == "sampler":
        eval_dataloader = DataLoader(
            dataset=dataset,
            sampler=RandomSampler(dataset)
        )
    else:
        eval_dataloader = DataLoader(
            dataset=dataset,
            batch_size=ArgMaxDatasetConfig.batch_size,
            shuffle=True,
        )

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        overfit_batches=overfit_batches,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        evaluate_dataloaders={"dl": eval_dataloader},
        metrics=metrics,
        n_epochs=2,
        output_from_new_proc="all",
        evaluate_every=-1,
        torch_kwargs={
            "non_blocking": False,
            "set_grad_to_none": True
        }
    )

    trainer.run()

    if dist.is_initialized():
        dist.destroy_process_group()
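
Assuming a standard pytest setup for the fastNLP test suite, the new test can presumably be selected via its torch marker and name, for example:

    pytest -m torch -k test_trainer_w_evaluator_w_samplers

It is parametrized over single-GPU (device 1) and two-GPU ([0, 1]) runs, over batch_sampler and sampler variants for both the train and evaluate dataloaders, and over overfit_batches in {-1, 0}.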