mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 05:07:44 +08:00
修复Evaluator的evaluate_use_dist_sampler在Trainer中的bug
This commit is contained in:
parent
b7ff2d1b4e
commit
d4bccf3f6a
@ -146,7 +146,9 @@ class Evaluator:
|
||||
|
||||
self.separator = kwargs.get('separator', '#')
|
||||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
|
||||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
|
||||
use_dist_sampler = kwargs.get("use_dist_sampler", None)
|
||||
if use_dist_sampler is None:
|
||||
use_dist_sampler = self.driver.is_distributed()
|
||||
if use_dist_sampler:
|
||||
self._dist_sampler = "unrepeatdist"
|
||||
else:
|
||||
@ -384,7 +386,7 @@ class _MetricsWrapper:
|
||||
# 如果数据是分布式的,但是不aggregate的话可能有问题
|
||||
if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
|
||||
logger.rank_zero_warning(
|
||||
"You have replace the sampler as distributed sampler when evaluation, but your metric "
|
||||
"You have replaced the sampler as distributed sampler when evaluation, but your metric "
|
||||
f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True)
|
||||
if metric.aggregate_when_get_metric is None:
|
||||
metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
|
||||
|
@ -288,7 +288,8 @@ class Trainer(TrainerEventTrigger):
|
||||
|
||||
* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
|
||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
|
||||
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;默认为 ``True``;
|
||||
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;
|
||||
不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致;
|
||||
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||
@ -477,7 +478,7 @@ class Trainer(TrainerEventTrigger):
|
||||
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
|
||||
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
|
||||
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
|
||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
|
||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
|
||||
progress_bar=progress_bar)
|
||||
|
||||
if train_fn is not None and not isinstance(train_fn, str):
|
||||
|
214
tests/core/controllers/test_evaluator_torch.py
Normal file
214
tests/core/controllers/test_evaluator_torch.py
Normal file
@ -0,0 +1,214 @@
|
||||
import pytest
|
||||
|
||||
from fastNLP import Metric, Evaluator
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from itertools import product
|
||||
|
||||
from fastNLP.core.controllers.trainer import Trainer
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
from fastNLP import Event
|
||||
|
||||
# 检查能否正确 aggregate
|
||||
|
||||
|
||||
class DistMetric(Metric):
|
||||
def __init__(self, aggregate_when_get_metric=None):
|
||||
super().__init__(aggregate_when_get_metric=aggregate_when_get_metric)
|
||||
self.register_element('count', value=0, aggregate_method='sum')
|
||||
self.data = 0
|
||||
|
||||
def update(self, y):
|
||||
self.count += len(y)
|
||||
self.data += len(y)
|
||||
|
||||
def get_metric(self) -> dict:
|
||||
count2 = sum(self.all_gather_object(self.data))
|
||||
return {'count': self.count.item(), 'count2': count2}
|
||||
|
||||
def reset(self):
|
||||
self.data = 0
|
||||
|
||||
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
from torch.optim import SGD
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
|
||||
class DataSet(Dataset):
|
||||
def __init__(self, num_samples=1000, num_features=10):
|
||||
g = torch.Generator()
|
||||
g.manual_seed(1000)
|
||||
self.data = torch.randn(num_samples, num_features, generator=g)
|
||||
self.y = self.data.argmax(dim=-1)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {'x': self.data[item], 'y': self.data[item]}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalClassificationTrainTorchConfig:
|
||||
num_labels: int = 10
|
||||
feature_dimension: int = 10
|
||||
seed: int = 0
|
||||
|
||||
batch_size: int = 4
|
||||
shuffle: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerParameters:
|
||||
model: Any = None
|
||||
optimizers: Any = None
|
||||
train_dataloader: Any = None
|
||||
evaluate_dataloaders: Any = None
|
||||
input_mapping: Any = None
|
||||
output_mapping: Any = None
|
||||
metrics: Any = None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1], autouse=True)
|
||||
def trainer_params(request):
|
||||
trainer_params = TrainerParameters()
|
||||
|
||||
trainer_params.model = TorchNormalModel_Classification_1(
|
||||
num_labels=NormalClassificationTrainTorchConfig.num_labels,
|
||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension
|
||||
)
|
||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
|
||||
|
||||
dataset = DataSet(99, num_features=NormalClassificationTrainTorchConfig.feature_dimension)
|
||||
_dataloader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=NormalClassificationTrainTorchConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
trainer_params.train_dataloader = _dataloader
|
||||
trainer_params.evaluate_dataloaders = _dataloader
|
||||
|
||||
return trainer_params
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize('device', [[0, 1], None])
|
||||
@magic_argv_env_context
|
||||
def test_1(trainer_params: TrainerParameters, device):
|
||||
# 测试能否自动 aggregate 。
|
||||
for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]):
|
||||
metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)
|
||||
|
||||
evaluator = Evaluator(trainer_params.model, dataloaders=trainer_params.evaluate_dataloaders,
|
||||
metrics={'c': metric},
|
||||
driver='torch', device=device, use_dist_sampler=use_dist_sampler,
|
||||
progress_bar='tqdm')
|
||||
if use_dist_sampler is None:
|
||||
use_dist_sampler = device is not None
|
||||
results = evaluator.run()
|
||||
num_samples = len(trainer_params.evaluate_dataloaders.dataset)
|
||||
if device is None:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
else:
|
||||
if aggregate_when_get_metric is True and use_dist_sampler is True:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
elif aggregate_when_get_metric is True and use_dist_sampler is False:
|
||||
assert results['count#c'] == 2*num_samples
|
||||
assert results['count2#c'] == 2*num_samples
|
||||
elif aggregate_when_get_metric is False and use_dist_sampler is True:
|
||||
assert results['count#c'] in (49, 50) # 不同卡,数量不同
|
||||
assert results['count2#c'] in (49, 50)
|
||||
else:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize('device', [[0, 1], None])
|
||||
@magic_argv_env_context
|
||||
def test_2(trainer_params: TrainerParameters, device):
|
||||
# 测试能否自动 aggregate 。
|
||||
for aggregate_when_get_metric, use_dist_sampler in product([True, False], [True, False, None]):
|
||||
metric = DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)
|
||||
|
||||
num_samples = len(trainer_params.evaluate_dataloaders.dataset)
|
||||
|
||||
@Trainer.on(Event.on_sanity_check_end())
|
||||
def on_valid_end(trainer, results):
|
||||
if device is None:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
else:
|
||||
if aggregate_when_get_metric is True and use_dist_sampler is True:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
elif aggregate_when_get_metric is True and use_dist_sampler is False:
|
||||
assert results['count#c'] == 2 * num_samples
|
||||
assert results['count2#c'] == 2 * num_samples
|
||||
elif aggregate_when_get_metric is False and use_dist_sampler is True:
|
||||
assert results['count#c'] in (49, 50) # 不同卡,数量不同
|
||||
assert results['count2#c'] in (49, 50)
|
||||
else:
|
||||
assert results['count#c'] == num_samples
|
||||
assert results['count2#c'] == num_samples
|
||||
|
||||
trainer = Trainer(
|
||||
model=trainer_params.model,
|
||||
driver='torch',
|
||||
device=device,
|
||||
optimizers=trainer_params.optimizers,
|
||||
train_dataloader=trainer_params.train_dataloader,
|
||||
evaluate_dataloaders=trainer_params.evaluate_dataloaders,
|
||||
metrics={'c': metric},
|
||||
evaluate_every=-1,
|
||||
n_epochs=0,
|
||||
output_from_new_proc="all",
|
||||
use_dist_sampler=use_dist_sampler,
|
||||
progress_bar='tqdm'
|
||||
)
|
||||
|
||||
if use_dist_sampler is None:
|
||||
use_dist_sampler = device is not None
|
||||
|
||||
trainer.run(num_eval_sanity_batch=-1)
|
||||
|
||||
trainer = Trainer(
|
||||
model=trainer_params.model,
|
||||
driver='torch',
|
||||
device=device,
|
||||
optimizers=trainer_params.optimizers,
|
||||
train_dataloader=trainer_params.train_dataloader,
|
||||
evaluate_dataloaders=trainer_params.evaluate_dataloaders,
|
||||
metrics={'c': DistMetric(aggregate_when_get_metric=aggregate_when_get_metric)},
|
||||
evaluate_every=-1,
|
||||
n_epochs=0,
|
||||
output_from_new_proc="all",
|
||||
use_dist_sampler=not (use_dist_sampler is True), #取相反的值
|
||||
evaluate_use_dist_sampler=use_dist_sampler,
|
||||
progress_bar='rich' # 刚好测试一下可以替换 progress 么
|
||||
)
|
||||
trainer.run(num_eval_sanity_batch=-1)
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user