Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
yh_cc 2022-04-14 16:02:55 +08:00
commit 8a0ffd6278
4 changed files with 30 additions and 19 deletions

View File

@ -244,7 +244,6 @@ class PaddleFleetDriver(PaddleDriver):
"""
if self.local_rank == 0:
# 是 rank0 的话,则拉起其它子进程
print("in launcher")
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
# 设置参数和初始化分布式环境
@ -326,7 +325,6 @@ class PaddleFleetDriver(PaddleDriver):
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用;
# 注意这里不需要调用 dist_sampler.set_distributed因为如果用户使用的是 TorchDDPDriver那么其在 Trainer 初始化的时候就已经调用了该函数;
if isinstance(dist, ReproducibleBatchSampler):
dist.set_distributed(
num_replicas=self.world_size,
@ -346,15 +344,16 @@ class PaddleFleetDriver(PaddleDriver):
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
if isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
return dataloader
# trainer
elif dist == "dist":

View File

@ -66,8 +66,8 @@ class PaddleDriver(Driver):
:param set_to_none: 用来判断是否需要将梯度直接置为 NonePaddle中这个参数无效
"""
# if set_to_none:
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
if set_to_none:
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
for optimizer in self.optimizers:
optimizer.clear_grad()
@ -254,8 +254,21 @@ class PaddleDriver(Driver):
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
sampler_states = sampler.state_dict()
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为直接使用sampler中的num_consumed_samples
# 会造成多余实际消耗的问题。
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except: # 有可能 batch_size 为 None就只有损失精度了
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

View File

@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
if isinstance(args.sampler, ReproducibleSampler):
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
return dataloader
# trainer
elif dist == "dist":

View File

@ -151,7 +151,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False
def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):