Reworked the checkpoint-resume (resume training) logic; the main changes are to trainer.save/load and to driver.save and driver.load.

YWMditto 2022-04-10 12:56:49 +08:00
parent 770b2eaf4b
commit a376eea776
3 changed files with 61 additions and 7 deletions


@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
real_monitor=self._real_monitor,
res=results)
if (monitor_value < self.monitor_value and self.larger_better is False) or \
(monitor_value > self.monitor_value and self.larger_better):
self.monitor_value = monitor_value
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
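The condition above treats a drop in the monitored value as an improvement when larger_better is False (e.g. a loss) and a rise as an improvement when it is True (e.g. accuracy). A minimal standalone sketch of that check (the function name is illustrative, not part of the callback's API):

def is_improvement(monitor_value: float, best_so_far: float, larger_better: bool) -> bool:
    # Mirrors the callback's condition: smaller is better for losses,
    # larger is better for metrics such as accuracy.
    return (monitor_value < best_so_far and not larger_better) or \
           (monitor_value > best_so_far and larger_better)

assert is_improvement(0.31, 0.35, larger_better=False)      # loss went down -> improvement
assert not is_improvement(0.80, 0.85, larger_better=True)   # accuracy went down -> no improvement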


@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler
class TorchDriver(Driver):
@@ -178,8 +179,28 @@ class TorchDriver(Driver):
model.load_state_dict(res.state_dict())
@rank_zero_call
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
# 1. Save the model state;
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
# The `dataloader` argument is the trainer's `dataloader` attribute: the driver never mutates its own dataloaders;
# instead, the dataloader state is changed by reassigning trainer.dataloader so that it fits the training or evaluation setting;
# 1. Save the sampler state, because we support resume training, i.e. resuming exactly at a specific batch.
# First, a PyTorch DataLoader always has a sampler; moreover, when resuming from a checkpoint the dataloader's
# sampler will have been replaced in `replace_sampler` with a `ReproducibleIterator`, or, in the single-device case,
# its batch_sampler with a `ReproducibleBatchSampler`
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, so training cannot be resumed from the exact batch.')
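The duck-typed check above only requires the (batch) sampler to expose state_dict()/load_state_dict(). A minimal sketch of what such a pair might record, assuming the sampler tracks how many samples it has already yielded (field names are illustrative; the actual ReproducibleBatchSampler may store more):

class _ResumableSamplerSketch:
    # Illustrative only: records epoch progress so iteration can resume at the exact batch.
    def __init__(self):
        self.num_consumed_samples = 0   # samples already yielded in the current epoch
        self.seed = 0                   # RNG seed so a shuffled order can be replayed

    def state_dict(self):
        return {"num_consumed_samples": self.num_consumed_samples, "seed": self.seed}

    def load_state_dict(self, states):
        self.num_consumed_samples = states["num_consumed_samples"]
        self.seed = states["seed"]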
# 2. Save the model state;
if should_save_model:
model = self.unwrap_model()
if only_state_dict:
@@ -191,7 +212,7 @@ class TorchDriver(Driver):
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
logger.debug("Save model")
# 2. Save the optimizers' state;
# 3. Save the optimizers' state;
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
optimizer: torch.optim.Optimizer = self.optimizers[i]
@@ -203,7 +224,7 @@ class TorchDriver(Driver):
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
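After this torch.save call the checkpoint file bundles the sampler state, the optimizer states, and whatever trainer-level entries were already in `states`. A hedged sketch of inspecting such a file (the folder path is illustrative; the key names come from this diff):

import torch
from pathlib import Path
from fastNLP.envs import FASTNLP_CHECKPOINT_FILENAME

folder = Path("checkpoints/epoch_1-batch_100")   # illustrative path
ckpt = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
print(list(ckpt.keys()))              # expect at least: sampler_states, optimizers_state_dict
print(ckpt["sampler_states"])         # written in step 1 above
print(ckpt["optimizers_state_dict"])  # written in step 3 above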
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
# 1. Load the optimizers' state;
@@ -224,6 +245,39 @@ class TorchDriver(Driver):
model.load_state_dict(res.state_dict())
logger.debug("Load model.")
# 3. Restore the sampler state;
dataloader_args = self.get_dataloader_args(dataloader)
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# the sampler cannot restore its own state, so wrap it in a ReproducibleBatchSampler here
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
# 4. Adjust trainer_state.batch_idx_in_epoch;
# the sampler here is a RandomSampler-like sampler, not a batch_sampler
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch
states["batch_idx_in_epoch"] = batch_idx_in_epoch
return states
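The arithmetic above recovers how many batches were already produced this epoch from the number of samples the sampler still has left, using ceiling division when drop_last is False. A small standalone check with made-up numbers:

import math

def batches_already_done(total_samples, num_left_samples, batch_size, drop_last):
    # Mirrors the computation in load(): total batches in the epoch minus
    # the batches that can still be formed from the remaining samples.
    if drop_last:
        return total_samples // batch_size - num_left_samples // batch_size
    return math.ceil(total_samples / batch_size) - math.ceil(num_left_samples / batch_size)

# 10 samples, batch_size 3, 4 samples left: 4 total batches, 2 still to come, 2 already done.
assert batches_already_done(10, 4, batch_size=3, drop_last=False) == 2
# With drop_last=True the incomplete final batch is never produced: 3 total, 1 left, 2 done.
assert batches_already_done(10, 4, batch_size=3, drop_last=True) == 2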
def get_evaluate_context(self):


@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
dist.destroy_process_group()
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@pytest.mark.parametrize("only_state_dict", [True, False])
@magic_argv_env_context
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(
# Test saving and loading a huggingface transformers model by writing custom model_save_fn and model_load_fn;
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_trainer_checkpoint_callback_2(