Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
x54-729 2022-05-10 07:31:46 +00:00
commit 7f42451d6d
3 changed files with 82 additions and 10 deletions

View File

@@ -3,9 +3,7 @@ __all__ = [
'Callback',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
"LRSchedCallback",
@@ -54,7 +52,6 @@ __all__ = [
'DataSet',
'FieldArray',
'Instance',
'ApplyResultException',
# drivers
"TorchSingleDriver",

View File

@@ -180,14 +180,16 @@ class CallbackManager:
states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer)
if len(_duplicated_callbacks) > 0:
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, "
f"and we will only save the first callback's state we meet.")
logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, "
f"fastNLP will only save the first callback's state.")
# 2. 每一个具体的 callback 函数的 filter 的状态;
_record_duplicated_callback_names = set()
for each_callback_filters in self._callback_filters:
if each_callback_filters[0] not in _record_duplicated_callback_names:
_record_duplicated_callback_names.add(each_callback_filters[0])
if 'filter_states' not in states[each_callback_filters[0]]:
states[each_callback_filters[0]]["filter_states"] = {}
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict()
# 3. 保存 callback_counter
@@ -214,13 +216,15 @@ class CallbackManager:
if each_callback_filters[0] in states:
if each_callback_filters[0] not in _already_loaded_callback_names:
_already_loaded_callback_names.add(each_callback_filters[0])
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]])
if 'filter_states' in states[each_callback_filters[0]] and \
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']:
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]])
else:
_duplicated_callback_names.add(each_callback_filters[0])
if len(_duplicated_callback_names) > 0:
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, "
f"and we will only load the first callback's state we meet.")
logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, "
f"fastNLP will only load the first callback's state.")
# 2. 再恢复每一个 callback 的单独的状态;
# 每一个我们自己提供的类 callback都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样,
@@ -231,8 +235,6 @@ class CallbackManager:
_already_loaded_callback_names.add(each_callback.callback_name)
# 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态;
each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"])
else:
each_callback.on_load_checkpoint(trainer, None)
@property
def has_trainer_checkpoint(self) -> bool:

View File

@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.distributed import rank_zero_rm
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from tests.helpers.utils import Capturing
from torchmetrics import Accuracy
from fastNLP.core.log import logger
@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1(
dist.destroy_process_group()
@pytest.mark.torch
def test_load_state(model_and_optimizers):
    """Verify that ``Trainer.load`` restores callback and filter state from a checkpoint.

    Flow: train 3 epochs while a ``CheckpointCallback`` saves the trainer every
    epoch, then build a fresh ``Trainer`` with *new* ``StateCallback`` instances,
    load the epoch-2 checkpoint, finish the run, and assert on the captured
    stdout that callback state was (partially) restored.

    NOTE(review): both ``StateCallback`` instances presumably share the same
    ``callback_name``, so only the first one's state is saved/loaded — that is
    why the assertions expect the first callback to print its *old* name and
    the second to keep its *new* name. Confirm against
    ``CallbackManager.on_save_checkpoint`` duplicate handling.
    """
    try:
        # Working directory for checkpoints; cleaned up in ``finally``.
        path = Path.cwd().joinpath(f"test_model_checkpoint")
        path.mkdir(exist_ok=True, parents=True)

        from fastNLP import Event, Callback

        # Event-filter callback: fires every 3rd ``on_before_backward``.
        # ``marker='all'`` attaches it to every Trainer created afterwards,
        # so its filter counter is part of the saved/restored state.
        @Trainer.on(Event.on_before_backward(every=3), marker='all')
        def print_outputs(*args):
            print("????")

        # Minimal stateful callback: round-trips ``name`` through the
        # checkpoint and prints it at the end of training.
        class StateCallback(Callback):
            def __init__(self, name):
                self.name = name

            def on_save_checkpoint(self, trainer):
                # State persisted into the checkpoint for this callback.
                return {'name': self.name}

            def on_load_checkpoint(self, trainer, states):
                # Restore persisted state; raises KeyError if states is malformed.
                self.name = states['name']

            def on_train_end(self, trainer):
                # Printed output is asserted on below via Capturing().
                print(self.name)

        callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'),
                     CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')]

        # First training run: produces per-epoch trainer checkpoints.
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        # Checkpoints are written under a FASTNLP_LAUNCH_TIME subfolder;
        # pick the one saved after epoch 2.
        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        epoch_2_path = all_saved_model_paths['trainer-epoch_2']

        # Second run with *new* callback names — loading the checkpoint should
        # overwrite (some of) them with the saved 'old_*' state.
        callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')]
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.load(folder=epoch_2_path)
        with Capturing() as output:
            trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        # First callback's saved state was restored ...
        assert 'old_callback1' in output[0]
        # ... while the duplicate-named second callback kept its new state.
        assert 'new_callback2' in output[0]
        # The every=3 filter counter was restored too: "????" contains '???'
        # exactly once (non-overlapping count), so exactly one print fired
        # in the resumed portion of training.
        assert output[0].count('???')==1
    finally:
        # Best-effort cleanup of the checkpoint directory on rank 0.
        rank_zero_rm(path)
@pytest.mark.torch
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)