mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
重新修改了断点重训的逻辑,主要修改了 trainer.save/load 和 driver.save 和 load 函数
This commit is contained in:
parent
770b2eaf4b
commit
a376eea776
@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
|
||||
real_monitor=self._real_monitor,
|
||||
res=results)
|
||||
if (monitor_value < self.monitor_value and self.larger_better is False) or \
|
||||
(monitor_value > self.monitor_value and self.larger_better):
|
||||
(monitor_value > self.monitor_value and self.larger_better):
|
||||
self.monitor_value = monitor_value
|
||||
if self.real_save_folder:
|
||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
|
||||
|
@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler
|
||||
|
||||
|
||||
class TorchDriver(Driver):
|
||||
@ -178,8 +179,28 @@ class TorchDriver(Driver):
|
||||
model.load_state_dict(res.state_dict())
|
||||
|
||||
@rank_zero_call
|
||||
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
# 1. 保存模型的状态;
|
||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变
|
||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
|
||||
|
||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
|
||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的
|
||||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif dataloader_args.sampler:
|
||||
sampler = dataloader_args.sampler
|
||||
else:
|
||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
|
||||
|
||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
|
||||
states['sampler_states'] = sampler.state_dict()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
|
||||
|
||||
# 2. 保存模型的状态;
|
||||
if should_save_model:
|
||||
model = self.unwrap_model()
|
||||
if only_state_dict:
|
||||
@ -191,7 +212,7 @@ class TorchDriver(Driver):
|
||||
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
|
||||
logger.debug("Save model")
|
||||
|
||||
# 2. 保存 optimizers 的状态;
|
||||
# 3. 保存 optimizers 的状态;
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
@ -203,7 +224,7 @@ class TorchDriver(Driver):
|
||||
states["optimizers_state_dict"] = optimizers_state_dict
|
||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
# 1. 加载 optimizers 的状态;
|
||||
@ -224,6 +245,39 @@ class TorchDriver(Driver):
|
||||
model.load_state_dict(res.state_dict())
|
||||
logger.debug("Load model.")
|
||||
|
||||
# 3. 恢复 sampler 的状态;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
|
||||
sampler = dataloader_args.sampler
|
||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
|
||||
# 说明原有的 sampler 无法恢复状态,这里需要使用 ReproducibleBatchSampler 将其包装一下
|
||||
if self.is_distributed():
|
||||
raise RuntimeError(
|
||||
"It is not allowed to use single device checkpoint retraining before but ddp now.")
|
||||
sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=sampler,
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
)
|
||||
sampler.load_state_dict(states['sampler_states'])
|
||||
|
||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
|
||||
|
||||
# 4. 修改 trainer_state.batch_idx_in_epoch
|
||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
|
||||
if not isinstance(sampler, ReproducibleBatchSampler):
|
||||
if dataloader_args.drop_last:
|
||||
batch_idx_in_epoch = len(
|
||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
|
||||
else:
|
||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
|
||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
|
||||
# sampler 是 batch_sampler;
|
||||
else:
|
||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch
|
||||
|
||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch
|
||||
|
||||
return states
|
||||
|
||||
def get_evaluate_context(self):
|
||||
|
@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
|
||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
|
||||
@pytest.mark.parametrize("version", [0, 1])
|
||||
@pytest.mark.parametrize("only_state_dict", [True, False])
|
||||
@magic_argv_env_context
|
||||
@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(
|
||||
|
||||
|
||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
|
||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
|
||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
|
||||
@pytest.mark.parametrize("version", [0, 1])
|
||||
@magic_argv_env_context
|
||||
def test_trainer_checkpoint_callback_2(
|
||||
|
Loading…
Reference in New Issue
Block a user