mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
修改断点重训部分逻辑
This commit is contained in:
parent
ca93fccf62
commit
64fa182aeb
@ -325,7 +325,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \
|
||||
"FastNLP does not support `IteratorDataset` now."
|
||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用;
|
||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
@ -345,15 +344,16 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
# trainer, evaluator
|
||||
if dist is None:
|
||||
if reproducible:
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
|
||||
"control.")
|
||||
else:
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_sampler(dataloader, dist)
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
if isinstance(args.sampler, ReproducibleSampler):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
return dataloader
|
||||
# trainer
|
||||
elif dist == "dist":
|
||||
|
@ -66,8 +66,8 @@ class PaddleDriver(Driver):
|
||||
|
||||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。
|
||||
"""
|
||||
# if set_to_none:
|
||||
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
|
||||
if set_to_none:
|
||||
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.clear_grad()
|
||||
|
||||
@ -254,8 +254,21 @@ class PaddleDriver(Driver):
|
||||
else:
|
||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
|
||||
|
||||
num_consumed_batches = states.pop('num_consumed_batches')
|
||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
|
||||
states['sampler_states'] = sampler.state_dict()
|
||||
sampler_states = sampler.state_dict()
|
||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
|
||||
# 会造成多余实际消耗的问题。
|
||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
|
||||
if num_consumed_samples_array is not None:
|
||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
|
||||
try:
|
||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
|
||||
except: # 有可能 batch_size 为 None,就只有损失精度了
|
||||
num_consumed_batches = sampler_states['num_consumed_samples']
|
||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
|
||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
|
||||
|
Loading…
Reference in New Issue
Block a user