mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
7f42451d6d
@ -3,9 +3,7 @@ __all__ = [
|
||||
'Callback',
|
||||
'Event',
|
||||
'Filter',
|
||||
'CallbackManager',
|
||||
'CheckpointCallback',
|
||||
'choose_progress_callback',
|
||||
'ProgressCallback',
|
||||
'RichCallback',
|
||||
"LRSchedCallback",
|
||||
@ -54,7 +52,6 @@ __all__ = [
|
||||
'DataSet',
|
||||
'FieldArray',
|
||||
'Instance',
|
||||
'ApplyResultException',
|
||||
|
||||
# drivers
|
||||
"TorchSingleDriver",
|
||||
|
@ -180,14 +180,16 @@ class CallbackManager:
|
||||
states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer)
|
||||
|
||||
if len(_duplicated_callbacks) > 0:
|
||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, "
|
||||
f"and we will only save the first callback's state we meet.")
|
||||
logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, "
|
||||
f"fastNLP will only save the first callback's state.")
|
||||
|
||||
# 2. 每一个具体的 callback 函数的 filter 的状态;
|
||||
_record_duplicated_callback_names = set()
|
||||
for each_callback_filters in self._callback_filters:
|
||||
if each_callback_filters[0] not in _record_duplicated_callback_names:
|
||||
_record_duplicated_callback_names.add(each_callback_filters[0])
|
||||
if 'filter_states' not in states[each_callback_filters[0]]:
|
||||
states[each_callback_filters[0]]["filter_states"] = {}
|
||||
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict()
|
||||
|
||||
# 3. 保存 callback_counter;
|
||||
@ -214,13 +216,15 @@ class CallbackManager:
|
||||
if each_callback_filters[0] in states:
|
||||
if each_callback_filters[0] not in _already_loaded_callback_names:
|
||||
_already_loaded_callback_names.add(each_callback_filters[0])
|
||||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]])
|
||||
if 'filter_states' in states[each_callback_filters[0]] and \
|
||||
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']:
|
||||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]])
|
||||
else:
|
||||
_duplicated_callback_names.add(each_callback_filters[0])
|
||||
|
||||
if len(_duplicated_callback_names) > 0:
|
||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, "
|
||||
f"and we will only load the first callback's state we meet.")
|
||||
logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, "
|
||||
f"fastNLP will only load the first callback's state.")
|
||||
|
||||
# 2. 再恢复每一个 callback 的单独的状态;
|
||||
# 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样,
|
||||
@ -231,8 +235,6 @@ class CallbackManager:
|
||||
_already_loaded_callback_names.add(each_callback.callback_name)
|
||||
# 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态;
|
||||
each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"])
|
||||
else:
|
||||
each_callback.on_load_checkpoint(trainer, None)
|
||||
|
||||
@property
|
||||
def has_trainer_checkpoint(self) -> bool:
|
||||
|
@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
|
||||
from tests.helpers.utils import Capturing
|
||||
from torchmetrics import Accuracy
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1(
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_load_state(model_and_optimizers):
|
||||
try:
|
||||
path = Path.cwd().joinpath(f"test_model_checkpoint")
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
from fastNLP import Event, Callback
|
||||
@Trainer.on(Event.on_before_backward(every=3), marker='all')
|
||||
def print_outputs(*args):
|
||||
print("????")
|
||||
|
||||
class StateCallback(Callback):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def on_save_checkpoint(self, trainer):
|
||||
return {'name': self.name}
|
||||
|
||||
def on_load_checkpoint(self, trainer, states):
|
||||
self.name = states['name']
|
||||
|
||||
def on_train_end(self, trainer):
|
||||
print(self.name)
|
||||
|
||||
callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'),
|
||||
CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')]
|
||||
|
||||
trainer = Trainer(
|
||||
model=model_and_optimizers.model,
|
||||
driver='torch',
|
||||
device='cpu',
|
||||
optimizers=model_and_optimizers.optimizers,
|
||||
train_dataloader=model_and_optimizers.train_dataloader,
|
||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
|
||||
input_mapping=model_and_optimizers.input_mapping,
|
||||
output_mapping=model_and_optimizers.output_mapping,
|
||||
metrics=model_and_optimizers.metrics,
|
||||
n_epochs=3,
|
||||
callbacks=callbacks,
|
||||
output_from_new_proc="all"
|
||||
)
|
||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
|
||||
|
||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
|
||||
epoch_2_path = all_saved_model_paths['trainer-epoch_2']
|
||||
|
||||
callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')]
|
||||
trainer = Trainer(
|
||||
model=model_and_optimizers.model,
|
||||
driver='torch',
|
||||
device='cpu',
|
||||
optimizers=model_and_optimizers.optimizers,
|
||||
train_dataloader=model_and_optimizers.train_dataloader,
|
||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
|
||||
input_mapping=model_and_optimizers.input_mapping,
|
||||
output_mapping=model_and_optimizers.output_mapping,
|
||||
metrics=model_and_optimizers.metrics,
|
||||
n_epochs=3,
|
||||
callbacks=callbacks,
|
||||
output_from_new_proc="all"
|
||||
)
|
||||
trainer.load(folder=epoch_2_path)
|
||||
with Capturing() as output:
|
||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
|
||||
|
||||
assert 'old_callback1' in output[0]
|
||||
assert 'new_callback2' in output[0]
|
||||
assert output[0].count('???')==1
|
||||
|
||||
finally:
|
||||
rank_zero_rm(path)
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
|
||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
|
||||
|
Loading…
Reference in New Issue
Block a user