From d4bccf3f6ad7fdca22b5057f20ccea145c25c314 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 15 May 2022 15:37:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DEvaluator=E7=9A=84evaluate=5F?= =?UTF-8?q?use=5Fdist=5Fsampler=E5=9C=A8Trainer=E4=B8=AD=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 6 +- fastNLP/core/controllers/trainer.py | 5 +- .../core/controllers/test_evaluator_torch.py | 214 ++++++++++++++++++ 3 files changed, 221 insertions(+), 4 deletions(-) create mode 100644 tests/core/controllers/test_evaluator_torch.py diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 908c3564..e8a8a872 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -146,7 +146,9 @@ class Evaluator: self.separator = kwargs.get('separator', '#') self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) - use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) + use_dist_sampler = kwargs.get("use_dist_sampler", None) + if use_dist_sampler is None: + use_dist_sampler = self.driver.is_distributed() if use_dist_sampler: self._dist_sampler = "unrepeatdist" else: @@ -384,7 +386,7 @@ class _MetricsWrapper: # 如果数据是分布式的,但是不aggregate的话可能有问题 if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: logger.rank_zero_warning( - "You have replace the sampler as distributed sampler when evaluation, but your metric " + "You have replaced the sampler as distributed sampler when evaluation, but your metric " f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True) if metric.aggregate_when_get_metric is None: metric.aggregate_when_get_metric = evaluator._dist_sampler is not None diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 81ddd3e8..01be134d 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -288,7 +288,8 @@ class Trainer(TrainerEventTrigger): * *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;默认为 ``True``; + * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``; + 不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致; * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; @@ -477,7 +478,7 @@ class Trainer(TrainerEventTrigger): driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), progress_bar=progress_bar) if train_fn is not None and not isinstance(train_fn, str): diff --git a/tests/core/controllers/test_evaluator_torch.py b/tests/core/controllers/test_evaluator_torch.py new file mode 100644 index 00000000..a30adea9 --- /dev/null +++ b/tests/core/controllers/test_evaluator_torch.py @@ -0,0 +1,214 @@ +import pytest + +from fastNLP import Metric, Evaluator + +from dataclasses import dataclass +from typing import Any +from itertools import product + +from fastNLP.core.controllers.trainer import Trainer +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset +from tests.helpers.utils import magic_argv_env_context +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP import Event + +# 检查能否正确 aggregate + + +class DistMetric(Metric): + def __init__(self, aggregate_when_get_metric=None): + super().__init__(aggregate_when_get_metric=aggregate_when_get_metric) + self.register_element('count', value=0, aggregate_method='sum') + self.data = 0 + + def update(self, y): + self.count += len(y) + self.data += len(y) + + def get_metric(self) -> dict: + count2 = sum(self.all_gather_object(self.data)) + return {'count': self.count.item(), 'count2': count2} + + def reset(self): + self.data = 0 + + + +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + import torch.distributed as dist + from torch.utils.data import Dataset + import torch + + + class DataSet(Dataset): + def __init__(self, num_samples=1000, num_features=10): + g = torch.Generator() + g.manual_seed(1000) + self.data = torch.randn(num_samples, num_features, generator=g) + self.y = self.data.argmax(dim=-1) + + def __getitem__(self, item): + return {'x': self.data[item], 'y': self.data[item]} + + def __len__(self): + return len(self.data) + + +@dataclass +class NormalClassificationTrainTorchConfig: + num_labels: int = 10 + feature_dimension: int = 10 + seed: int = 0 + + batch_size: int = 4 + shuffle: bool = True + + +@dataclass +class TrainerParameters: + model: Any = None + optimizers: Any = None + train_dataloader: Any = None + evaluate_dataloaders: Any = None + input_mapping: Any = None + output_mapping: Any = None + metrics: Any = None + + +@pytest.fixture(scope="module", params=[1], autouse=True) +def trainer_params(request): + trainer_params = TrainerParameters() + + trainer_params.model = TorchNormalModel_Classification_1( + num_labels=NormalClassificationTrainTorchConfig.num_labels, + feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension + ) + trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) + + dataset = DataSet(99, num_features=NormalClassificationTrainTorchConfig.feature_dimension) + _dataloader = DataLoader( + dataset=dataset, + batch_size=NormalClassificationTrainTorchConfig.batch_size, + shuffle=True + ) + trainer_params.train_dataloader = _dataloader + trainer_params.evaluate_dataloaders = _dataloader + + return trainer_params + + +@pytest.mark.torch +@pytest.mark.parametrize('device', [[0, 1], None]) +@magic_argv_env_context +def test_1(trainer_params: TrainerParameters, device): + # 测试能否自动 aggregate 。 + for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]): + metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric) + + evaluator = Evaluator(trainer_params.model, dataloaders=trainer_params.evaluate_dataloaders, + metrics={'c': metric}, + driver='torch', device=device, use_dist_sampler=use_dist_sampler, + progress_bar='tqdm') + if use_dist_sampler is None: + use_dist_sampler = device is not None + results = evaluator.run() + num_samples = len(trainer_params.evaluate_dataloaders.dataset) + if device is None: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + else: + if aggregate_when_get_metric is True and use_dist_sampler is True: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + elif aggregate_when_get_metric is True and use_dist_sampler is False: + assert results['count#c'] == 2*num_samples + assert results['count2#c'] == 2*num_samples + elif aggregate_when_get_metric is False and use_dist_sampler is True: + assert results['count#c'] in (49, 50) # 不同卡,数量不同 + assert results['count2#c'] in (49, 50) + else: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + + if dist.is_initialized(): + dist.destroy_process_group() + + + +@pytest.mark.torch +@pytest.mark.parametrize('device', [[0, 1], None]) +@magic_argv_env_context +def test_2(trainer_params: TrainerParameters, device): + # 测试能否自动 aggregate 。 + for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]): + metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric) + + num_samples = len(trainer_params.evaluate_dataloaders.dataset) + + @Trainer.on(Event.on_sanity_check_end()) + def on_valid_end(trainer, results): + if device is None: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + else: + if aggregate_when_get_metric is True and use_dist_sampler is True: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + elif aggregate_when_get_metric is True and use_dist_sampler is False: + assert results['count#c'] == 2 * num_samples + assert results['count2#c'] == 2 * num_samples + elif aggregate_when_get_metric is False and use_dist_sampler is True: + assert results['count#c'] in (49, 50) # 不同卡,数量不同 + assert results['count2#c'] in (49, 50) + else: + assert results['count#c'] == num_samples + assert results['count2#c'] == num_samples + + trainer = Trainer( + model=trainer_params.model, + driver='torch', + device=device, + optimizers=trainer_params.optimizers, + train_dataloader=trainer_params.train_dataloader, + evaluate_dataloaders=trainer_params.evaluate_dataloaders, + metrics={'c': metric}, + evaluate_every=-1, + n_epochs=0, + output_from_new_proc="all", + use_dist_sampler=use_dist_sampler, + progress_bar='tqdm' + ) + + if use_dist_sampler is None: + use_dist_sampler = device is not None + + trainer.run(num_eval_sanity_batch=-1) + + trainer = Trainer( + model=trainer_params.model, + driver='torch', + device=device, + optimizers=trainer_params.optimizers, + train_dataloader=trainer_params.train_dataloader, + evaluate_dataloaders=trainer_params.evaluate_dataloaders, + metrics={'c': DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)}, + evaluate_every=-1, + n_epochs=0, + output_from_new_proc="all", + use_dist_sampler=not (use_dist_sampler is True), #取相反的值 + evaluate_use_dist_sampler=use_dist_sampler, + progress_bar='rich' # 刚好测试一下可以替换 progress 么 + ) + trainer.run(num_eval_sanity_batch=-1) + + if dist.is_initialized(): + dist.destroy_process_group() + + + + + +