mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 12:47:35 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
1443065cbd
@ -10,7 +10,8 @@ __all__ = [
|
||||
'ProgressCallback',
|
||||
'RichCallback',
|
||||
"LRSchedCallback",
|
||||
'LoadBestModelCallback'
|
||||
'LoadBestModelCallback',
|
||||
"EarlyStopCallback"
|
||||
]
|
||||
|
||||
|
||||
@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb
|
||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
|
||||
from .lr_scheduler_callback import LRSchedCallback
|
||||
from .load_best_model_callback import LoadBestModelCallback
|
||||
from .early_stop_callback import EarlyStopCallback
|
||||
|
||||
|
@ -1,11 +1,15 @@
|
||||
from typing import Union, Callable, Dict, Optional
|
||||
from typing import Union, Callable, Dict, Optional, Any
|
||||
from abc import ABC
|
||||
|
||||
__all__ = [
|
||||
'Callback',
|
||||
]
|
||||
|
||||
from .callback_events import Events, EventsList, Filter
|
||||
from .utils import _get_monitor_value
|
||||
from fastNLP.core.callbacks.callback_events import _SingleEventState
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.utils import apply_to_collection
|
||||
|
||||
|
||||
class Callback:
|
||||
@ -150,4 +154,82 @@ class _CallbackWrapper(Callback):
|
||||
return self.fn.__name__
|
||||
|
||||
|
||||
class CanItemDataType(ABC):
|
||||
"""
|
||||
检测可以进行传输的对象。
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
|
||||
if cls is CanItemDataType:
|
||||
item = getattr(subclass, 'item', None)
|
||||
return callable(item)
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class HasMonitorCallback(Callback):
|
||||
def __init__(self, monitor, larger_better, must_have_monitor=False):
|
||||
self.set_monitor(monitor, larger_better)
|
||||
self.must_have_moinitor = must_have_monitor
|
||||
|
||||
def set_monitor(self, monitor, larger_better):
|
||||
self.monitor = str(monitor) if monitor is not None else None
|
||||
self.larger_better = bool(larger_better)
|
||||
if larger_better:
|
||||
self.monitor_value = float('-inf')
|
||||
else:
|
||||
self.monitor_value = float('inf')
|
||||
self._real_monitor = self.monitor
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
"""
|
||||
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
|
||||
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。
|
||||
|
||||
:param trainer:
|
||||
:param driver:
|
||||
:return:
|
||||
"""
|
||||
if self.monitor is None and trainer.monitor is not None:
|
||||
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
|
||||
if self.must_have_moinitor and self.monitor is None:
|
||||
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
|
||||
f"You can set it in the initialization or through Trainer.")
|
||||
|
||||
def get_monitor_value(self, results:Dict)->float:
|
||||
"""
|
||||
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。
|
||||
|
||||
:param results:
|
||||
:return:
|
||||
"""
|
||||
if len(results)==0:
|
||||
return 0
|
||||
# 保证所有的 tensor 都被转换为了 python 特定的类型
|
||||
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
|
||||
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
|
||||
real_monitor=self._real_monitor,
|
||||
res=results)
|
||||
if self._real_monitor != use_monitor: # 发生了替换需要打印
|
||||
logger.warning(
|
||||
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
|
||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
|
||||
self._real_monitor = use_monitor
|
||||
return monitor_value
|
||||
|
||||
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
|
||||
"""
|
||||
检测 monitor_value 是否是更好的
|
||||
|
||||
:param monitor_value:
|
||||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
|
||||
:return:
|
||||
"""
|
||||
better = False
|
||||
if (self.larger_better and monitor_value > self.monitor_value) or \
|
||||
(not self.larger_better and monitor_value < self.monitor_value):
|
||||
better = True
|
||||
if keep_if_better:
|
||||
self.monitor_value = monitor_value
|
||||
return better
|
@ -5,12 +5,12 @@ __all__ = [
|
||||
import os
|
||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
|
||||
from pathlib import Path
|
||||
from abc import ABC
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
import fastNLP
|
||||
from .callback import Callback, Filter
|
||||
from .callback import Callback, HasMonitorCallback
|
||||
from fastNLP.core.callbacks.utils import _get_monitor_value
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME
|
||||
@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
|
||||
from fastNLP.core.utils import apply_to_collection
|
||||
|
||||
|
||||
class CanItemDataType(ABC):
|
||||
"""
|
||||
检测可以进行传输的对象。
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
|
||||
if cls is CanItemDataType:
|
||||
item = getattr(subclass, 'item', None)
|
||||
return callable(item)
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
||||
class CheckpointCallback(Callback):
|
||||
class CheckpointCallback(HasMonitorCallback):
|
||||
def __init__(
|
||||
self,
|
||||
monitor,
|
||||
@ -48,13 +33,8 @@ class CheckpointCallback(Callback):
|
||||
model_save_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# 我们新加了逻辑,如果 checkpoint callback 自己没有设置 monitor 和 larger_better,那么我们会将其在 trainer 中的设置赋值给它们;
|
||||
# if monitor is None and save_topk is not None:
|
||||
# raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
|
||||
|
||||
if monitor is not None and not isinstance(monitor, str):
|
||||
raise ValueError("Parameter `monitor` should be of 'str' type.")
|
||||
|
||||
super().__init__(monitor=monitor, larger_better=larger_better,
|
||||
must_have_monitor=save_topk is not None)
|
||||
if save_folder is None:
|
||||
logger.warning(
|
||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
|
||||
@ -92,13 +72,12 @@ class CheckpointCallback(Callback):
|
||||
"`BaseException` type.")
|
||||
else:
|
||||
save_on_exception = []
|
||||
self.monitor = monitor
|
||||
|
||||
self.save_folder = Path(save_folder)
|
||||
self.save_every_n_epochs = save_every_n_epochs
|
||||
self.save_every_n_batches = save_every_n_batches
|
||||
self.save_last = save_last
|
||||
self.save_topk = save_topk
|
||||
self.larger_better = larger_better
|
||||
self.only_state_dict = only_state_dict
|
||||
self.model_save_fn = model_save_fn
|
||||
self.save_on_exception = save_on_exception
|
||||
@ -108,12 +87,6 @@ class CheckpointCallback(Callback):
|
||||
self._topk_model = {}
|
||||
self._topn = 0 # 表示目前已经保存了几个最好的模型;
|
||||
|
||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的
|
||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次
|
||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为;
|
||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key;
|
||||
self._real_monitor = self.monitor
|
||||
|
||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候,
|
||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的;
|
||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
|
||||
@ -121,20 +94,15 @@ class CheckpointCallback(Callback):
|
||||
synchronize_mkdir(self.timestamp_path)
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
if self.monitor is None:
|
||||
if trainer.monitor is not None:
|
||||
self.monitor = trainer.monitor
|
||||
self.larger_better = trainer.larger_better
|
||||
elif self.save_topk is not None:
|
||||
raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this"
|
||||
"callback or in trainer.")
|
||||
else:
|
||||
self.monitor = None
|
||||
if self.save_topk is not None:
|
||||
super().on_after_trainer_initialized(trainer, driver)
|
||||
if self.save_topk is not None and trainer.evaluator is None:
|
||||
raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.")
|
||||
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.")
|
||||
|
||||
def on_validate_end(self, trainer, validate_res):
|
||||
self._save_topk(trainer, validate_res)
|
||||
def on_validate_end(self, trainer, results):
|
||||
if len(results) == 0:
|
||||
return
|
||||
self._save_topk(trainer, results)
|
||||
|
||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
|
||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
|
||||
@ -157,7 +125,7 @@ class CheckpointCallback(Callback):
|
||||
|
||||
def on_sanity_check_end(self, trainer, sanity_check_res):
|
||||
# 主要核对一下 monitor 是否存在。
|
||||
self._get_validate_metric(sanity_check_res)
|
||||
self.get_monitor_value(results=sanity_check_res)
|
||||
|
||||
def on_save_checkpoint(self, trainer) -> Dict:
|
||||
"""
|
||||
@ -168,8 +136,7 @@ class CheckpointCallback(Callback):
|
||||
|
||||
states = {}
|
||||
states['timestamp_path'] = str(self.timestamp_path.absolute())
|
||||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
|
||||
function=lambda x:x.item())
|
||||
states['_topk_model'] = deepcopy(self._topk_model)
|
||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
|
||||
states['_real_monitor'] = self._real_monitor
|
||||
return states
|
||||
@ -190,30 +157,30 @@ class CheckpointCallback(Callback):
|
||||
self._topk_model.update(self._topk_model)
|
||||
self._real_monitor = states["real_monitor"]
|
||||
|
||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
|
||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
|
||||
"""
|
||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。
|
||||
|
||||
:param trainer:
|
||||
:param validate_res:
|
||||
:param results:
|
||||
:return:
|
||||
"""
|
||||
if self.save_topk is not None:
|
||||
_metric_value = self._get_validate_metric(validate_res)
|
||||
monitor_value = self.get_monitor_value(results=results)
|
||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
|
||||
f"-{self._real_monitor}_{_metric_value}"
|
||||
f"-{self._real_monitor}_{monitor_value}"
|
||||
|
||||
_should_save = False
|
||||
if self._topn < self.save_topk:
|
||||
self._topk_model[folder_name] = _metric_value
|
||||
self._topk_model[folder_name] = monitor_value
|
||||
self._topn += 1
|
||||
_should_save = True
|
||||
else:
|
||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model,
|
||||
key=lambda x: self._topk_model[x])
|
||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
|
||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
|
||||
self._topk_model[folder_name] = _metric_value
|
||||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
|
||||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
|
||||
self._topk_model[folder_name] = monitor_value
|
||||
_should_save = True
|
||||
self._topk_model.pop(_least_valuable_model)
|
||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))
|
||||
@ -249,7 +216,11 @@ class CheckpointCallback(Callback):
|
||||
:return:
|
||||
"""
|
||||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
|
||||
if self._real_monitor != use_monitor:
|
||||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
|
||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
|
||||
self._real_monitor = use_monitor
|
||||
|
||||
return value
|
||||
|
||||
@property
|
||||
@ -277,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback):
|
||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。
|
||||
|
||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
|
||||
的那个作为 monitor 。
|
||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。
|
||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
|
||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
|
||||
:param save_every_n_epochs: 多少个 epoch 保存一次。
|
||||
@ -324,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback):
|
||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。
|
||||
|
||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
|
||||
的那个作为 monitor 。
|
||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。
|
||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
|
||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
|
||||
:param save_every_n_epochs: 多少个 epoch 保存一次。
|
||||
|
61
fastNLP/core/callbacks/early_stop_callback.py
Normal file
61
fastNLP/core/callbacks/early_stop_callback.py
Normal file
@ -0,0 +1,61 @@
|
||||
__all__ = [
|
||||
'EarlyStopCallback'
|
||||
]
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .callback import HasMonitorCallback
|
||||
from fastNLP.core.utils.exceptions import EarlyStopException
|
||||
|
||||
|
||||
class EarlyStopCallback(HasMonitorCallback):
|
||||
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
|
||||
"""
|
||||
|
||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。
|
||||
:param larger_better: monitor 的值是否是越大越好。
|
||||
:param patience: 多少次 validate 不没有提升就停止。
|
||||
"""
|
||||
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
|
||||
self.wait = 0
|
||||
self.patience = patience
|
||||
|
||||
def on_validate_end(self, trainer, results):
|
||||
if len(results)==0:
|
||||
return
|
||||
monitor_value = self.get_monitor_value(results)
|
||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
|
||||
self.wait = 0
|
||||
else:
|
||||
self.wait += 1
|
||||
|
||||
def on_fetch_data_begin(self, trainer):
|
||||
# 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。
|
||||
if self.wait >= self.patience:
|
||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for "
|
||||
f"metric `{self._real_monitor}`")
|
||||
|
||||
def on_train_epoch_begin(self, trainer):
|
||||
# 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。
|
||||
if self.wait >= self.patience:
|
||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for "
|
||||
f"metric `{self._real_monitor}`(best value: {self.monitor_value})")
|
||||
|
||||
def on_save_checkpoint(self, trainer) -> Dict:
|
||||
states = {
|
||||
'patience': self.patience,
|
||||
'wait': self.wait,
|
||||
'monitor': self.monitor,
|
||||
'monitor_value': self.monitor_value
|
||||
}
|
||||
return states
|
||||
|
||||
def on_load_checkpoint(self, trainer, states):
|
||||
self.patience = states['patience']
|
||||
self.wait = states['wait']
|
||||
self.monitor = states['monitor']
|
||||
self.monitor_value = float(states['monitor_value'])
|
||||
|
||||
def callback_name(self):
|
||||
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
|
||||
|
@ -4,8 +4,7 @@ __all__ = [
|
||||
|
||||
import os
|
||||
from typing import Optional, Callable
|
||||
from .callback import Callback
|
||||
from .utils import _get_monitor_value
|
||||
from .callback import HasMonitorCallback
|
||||
from io import BytesIO
|
||||
import shutil
|
||||
|
||||
@ -14,15 +13,15 @@ from fastNLP.core.log import logger
|
||||
from fastNLP.envs import all_rank_call
|
||||
|
||||
|
||||
class LoadBestModelCallback(Callback):
|
||||
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True,
|
||||
class LoadBestModelCallback(HasMonitorCallback):
|
||||
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
|
||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
|
||||
model_load_fn:Optional[Callable] = None,
|
||||
delete_after_train:bool = True):
|
||||
"""
|
||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。
|
||||
|
||||
:param str monitor: 监控的 metric 值。
|
||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。
|
||||
:param larger_better: 该 metric 值是否是越大越好。
|
||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保
|
||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。
|
||||
@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback):
|
||||
请在函数内完成对模型的加载。
|
||||
:param delete_after_train: 在训练结束后是否删掉模型。
|
||||
"""
|
||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
|
||||
if model_load_fn is not None:
|
||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object."
|
||||
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time."
|
||||
@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback):
|
||||
self.real_save_folder = None
|
||||
self.buffer = BytesIO()
|
||||
|
||||
self.monitor = monitor
|
||||
self.larger_better = larger_better
|
||||
self.save_folder = save_folder
|
||||
self.only_state_dict = only_state_dict
|
||||
self.model_save_fn = model_save_fn
|
||||
self.model_load_fn = model_load_fn
|
||||
self.delete_after_after = delete_after_train
|
||||
self._real_monitor = None
|
||||
self.monitor_value = float('-inf') if larger_better else float('inf')
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
|
||||
@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback):
|
||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to "
|
||||
f"save best model when launch using script.")
|
||||
|
||||
super().on_after_trainer_initialized(trainer, driver)
|
||||
|
||||
def on_sanity_check_end(self, trainer, sanity_check_res):
|
||||
self.get_monitor_value(sanity_check_res)
|
||||
|
||||
def on_validate_end(self, trainer, results):
|
||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
|
||||
real_monitor=self._real_monitor,
|
||||
res=results)
|
||||
if (monitor_value < self.monitor_value and self.larger_better is False) or \
|
||||
(monitor_value > self.monitor_value and self.larger_better):
|
||||
self.monitor_value = monitor_value
|
||||
if len(results)==0:
|
||||
return
|
||||
monitor_value = self.get_monitor_value(results)
|
||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
|
||||
if self.real_save_folder:
|
||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
|
||||
model_save_fn=self.model_save_fn)
|
||||
|
@ -8,7 +8,7 @@ __all__ = [
|
||||
'RichCallback'
|
||||
]
|
||||
|
||||
from .callback import Callback
|
||||
from .callback import HasMonitorCallback
|
||||
from fastNLP.core.callbacks.utils import _get_monitor_value
|
||||
from fastNLP.core.utils import f_rich_progress
|
||||
from fastNLP.core.log import logger
|
||||
@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str):
|
||||
return None
|
||||
|
||||
|
||||
class ProgressCallback(Callback):
|
||||
class ProgressCallback(HasMonitorCallback):
|
||||
def on_train_end(self, trainer):
|
||||
f_rich_progress.stop()
|
||||
|
||||
def on_sanity_check_end(self, trainer, sanity_check_res):
|
||||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None:
|
||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
|
||||
real_monitor=self._real_monitor,
|
||||
res=sanity_check_res)
|
||||
self.get_monitor_value(sanity_check_res)
|
||||
|
||||
|
||||
class RichCallback(ProgressCallback):
|
||||
@ -46,28 +44,22 @@ class RichCallback(ProgressCallback):
|
||||
|
||||
:param print_every: 多少个 batch 更新一次显示。
|
||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
|
||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。
|
||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。
|
||||
:param larger_better: 是否是monitor的结果越大越好。
|
||||
:param format_json: 是否format json再打印
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
|
||||
self.print_every = print_every
|
||||
self.progress_bar = f_rich_progress
|
||||
self.task2id = {}
|
||||
self.loss = 0
|
||||
self.loss_round_ndigit = loss_round_ndigit
|
||||
self.monitor = monitor
|
||||
self.larger_better = larger_better
|
||||
if larger_better:
|
||||
self.monitor_value = float('-inf')
|
||||
else:
|
||||
self.monitor_value = float('inf')
|
||||
self._real_monitor = monitor
|
||||
self.format_json = format_json
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
if not self.progress_bar.disable:
|
||||
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0)
|
||||
super(RichCallback, self).on_after_trainer_initialized(trainer, driver)
|
||||
|
||||
def on_train_begin(self, trainer):
|
||||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
|
||||
@ -109,16 +101,12 @@ class RichCallback(ProgressCallback):
|
||||
text_style = ''
|
||||
characters = '-'
|
||||
if self.monitor is not None:
|
||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
|
||||
real_monitor=self._real_monitor,
|
||||
res=results)
|
||||
if (self.larger_better and monitor_value > self.monitor_value) or \
|
||||
(not self.larger_better and monitor_value < self.monitor_value):
|
||||
monitor_value = self.get_monitor_value(results)
|
||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
|
||||
if abs(self.monitor_value) != float('inf'):
|
||||
rule_style = 'spring_green3'
|
||||
text_style = '[bold]'
|
||||
characters = '+'
|
||||
self.monitor_value = monitor_value
|
||||
self.progress_bar.print()
|
||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, "
|
||||
f"Batch:{trainer.batch_idx_in_epoch}",
|
||||
@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback):
|
||||
:param larger_better: 是否是monitor的结果越大越好。
|
||||
:param format_json: 是否format json再打印
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
|
||||
self.print_every = print_every
|
||||
self.task2id = {}
|
||||
self.loss = 0
|
||||
self.loss_round_ndigit = loss_round_ndigit
|
||||
self.monitor = monitor
|
||||
self.larger_better = larger_better
|
||||
if larger_better:
|
||||
self.monitor_value = float('-inf')
|
||||
else:
|
||||
self.monitor_value = float('inf')
|
||||
self._real_monitor = monitor
|
||||
self.set_monitor(monitor, larger_better)
|
||||
self.format_json = format_json
|
||||
self.num_signs = 10
|
||||
|
||||
@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback):
|
||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
|
||||
text = ''
|
||||
if self.monitor is not None:
|
||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
|
||||
real_monitor=self._real_monitor,
|
||||
res=results)
|
||||
if (self.larger_better and monitor_value > self.monitor_value) or \
|
||||
(not self.larger_better and monitor_value < self.monitor_value):
|
||||
monitor_value = self.get_monitor_value(results)
|
||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
|
||||
if abs(self.monitor_value) != float('inf'):
|
||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs
|
||||
self.monitor_value = monitor_value
|
||||
if len(text) == 0:
|
||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs
|
||||
|
||||
|
@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
|
||||
if monitor in res:
|
||||
return monitor, res[monitor]
|
||||
|
||||
if real_monitor in res:
|
||||
return real_monitor, res[real_monitor]
|
||||
|
||||
pairs = []
|
||||
for idx, (key, value) in enumerate(res.items()):
|
||||
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor))
|
||||
pairs.append((key, value, match.size, idx))
|
||||
match_size = _match_length(monitor, key)
|
||||
pairs.append((key, value, match_size, idx))
|
||||
|
||||
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True)
|
||||
key, value, match_size = pairs[0][:3]
|
||||
|
||||
if real_monitor is not None and real_monitor in res and real_monitor != key:
|
||||
# 如果 real_monitor 比新找的更长就继续用之前的。
|
||||
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor))
|
||||
if match.size > match_size:
|
||||
return real_monitor, res[real_monitor]
|
||||
|
||||
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), "
|
||||
f"we use the `{key}` as the monitor.")
|
||||
real_monitor = key
|
||||
return real_monitor, value
|
||||
return key, value
|
||||
|
||||
|
||||
def _match_length(a:str, b:str)->int:
|
||||
"""
|
||||
需要把长度短的放在前面
|
||||
|
||||
:param a:
|
||||
:param b:
|
||||
:return:
|
||||
"""
|
||||
short = a if len(a) < len(b) else b
|
||||
long = a if len(a)>=len(b) else b
|
||||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long))
|
||||
return match.size
|
||||
|
||||
|
||||
|
@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME
|
||||
from fastNLP.core.utils.exceptions import EarlyStopException
|
||||
|
||||
|
||||
class Trainer(TrainerEventTrigger):
|
||||
@ -50,6 +51,8 @@ class Trainer(TrainerEventTrigger):
|
||||
model_wo_auto_param_call: bool = False,
|
||||
accumulation_steps: int = 1,
|
||||
fp16: bool = False,
|
||||
monitor: str = None,
|
||||
larger_better: bool = True,
|
||||
marker: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
@ -106,6 +109,10 @@ class Trainer(TrainerEventTrigger):
|
||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`;
|
||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1;
|
||||
:param fp16: 是否开启混合精度训练;默认为 False;
|
||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有
|
||||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
|
||||
的那个作为 monitor 。
|
||||
:param larger_better: monitor 的值是否是越大越好。
|
||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
|
||||
:param kwargs: 一些其它的可能需要的参数;
|
||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||
@ -213,6 +220,8 @@ class Trainer(TrainerEventTrigger):
|
||||
self.evaluator = None
|
||||
self.epoch_validate = lambda *args, **kwargs: ...
|
||||
self.step_validate = lambda *args, **kwargs: ...
|
||||
self.monitor = monitor
|
||||
self.larger_better = larger_better
|
||||
if metrics is not None and validate_dataloaders is not None:
|
||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
|
||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
|
||||
@ -242,6 +251,7 @@ class Trainer(TrainerEventTrigger):
|
||||
else:
|
||||
# validate_every > 0
|
||||
self._step_validate_filter = Filter(every=validate_every)
|
||||
|
||||
self.metrics = metrics
|
||||
self.validate_every = validate_every
|
||||
|
||||
@ -323,6 +333,10 @@ class Trainer(TrainerEventTrigger):
|
||||
self.driver.barrier()
|
||||
self.on_train_end()
|
||||
self.driver.barrier()
|
||||
|
||||
except EarlyStopException as e:
|
||||
logger.info(f"Catch early stop exception: {e.msg}.")
|
||||
self.on_exception(e)
|
||||
except KeyboardInterrupt as e:
|
||||
self.driver.on_exception()
|
||||
self.on_exception(e)
|
||||
|
@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp):
|
||||
|
||||
|
||||
class ClassifyFPreRecMetric(Metric):
|
||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
|
||||
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
|
||||
only_gross: bool = True, f_type='micro', beta=1) -> None:
|
||||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
|
||||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
|
||||
aggregate_when_get_metric: bool = False) -> None:
|
||||
super(ClassifyFPreRecMetric, self).__init__(backend=backend,
|
||||
aggregate_when_get_metric=aggregate_when_get_metric)
|
||||
if f_type not in ('micro', 'macro'):
|
||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
|
||||
|
||||
if tag_vocab:
|
||||
if not isinstance(tag_vocab, Vocabulary):
|
||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
|
||||
self.ignore_labels = ignore_labels
|
||||
self.f_type = f_type
|
||||
self.beta = beta
|
||||
@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric):
|
||||
|
||||
self.tag_vocab = tag_vocab
|
||||
|
||||
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\
|
||||
defaultdict(partial(self.register_element, aggregate_method='sum')),\
|
||||
defaultdict(partial(self.register_element, aggregate_method='sum'))
|
||||
self._tp = {}
|
||||
self._fp = {}
|
||||
self._fn = {}
|
||||
if tag_vocab:
|
||||
for word, _ in tag_vocab:
|
||||
word = word.lower()
|
||||
if word != 'o':
|
||||
word = word[2:]
|
||||
if word in self._true_positives:
|
||||
continue
|
||||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
elif num_class > 0:
|
||||
for word in range(num_class):
|
||||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
|
||||
backend=backend)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
def get_metric(self) -> dict:
|
||||
r"""
|
||||
@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric):
|
||||
tag_name = self.tag_vocab.to_word(tag)
|
||||
else:
|
||||
tag_name = int(tag)
|
||||
tp = self._tp[tag]
|
||||
fn = self._fn[tag]
|
||||
fp = self._fp[tag]
|
||||
tp = self._tp[tag].get_scalar()
|
||||
fn = self._fn[tag].get_scalar()
|
||||
fp = self._fp[tag].get_scalar()
|
||||
if tp == fn == fp == 0:
|
||||
continue
|
||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
|
||||
f_sum += f
|
||||
pre_sum += pre
|
||||
@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric):
|
||||
|
||||
if self.f_type == 'micro':
|
||||
f, pre, rec = _compute_f_pre_rec(self.beta_square,
|
||||
sum(self._tp.values()),
|
||||
sum(self._fn.values()),
|
||||
sum(self._fp.values()))
|
||||
sum(val.get_scalar() for val in self._tp.values()),
|
||||
sum(val.get_scalar() for val in self._fn.values()),
|
||||
sum(val.get_scalar() for val in self._fp.values()))
|
||||
evaluate_result['f'] = f
|
||||
evaluate_result['pre'] = pre
|
||||
evaluate_result['rec'] = rec
|
||||
|
||||
|
||||
for key, value in evaluate_result.items():
|
||||
evaluate_result[key] = round(value, 6)
|
||||
|
||||
return evaluate_result
|
||||
|
||||
def update(self, pred, target, seq_len=None):
|
||||
r"""
|
||||
evaluate函数将针对一个批次的预测结果做评价指标的累计
|
||||
|
||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
|
||||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
|
||||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
|
||||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
|
||||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
|
||||
如果mask也被传进来的话seq_len会被忽略.
|
||||
"""
|
||||
pred = self.tensor2numpy(pred)
|
||||
target = self.tensor2numpy(target)
|
||||
if seq_len is not None:
|
||||
@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric):
|
||||
f"pred have element numbers: {len(target.flatten())}")
|
||||
|
||||
pass
|
||||
elif len(pred.ndim) == len(target.ndim) + 1:
|
||||
elif pred.ndim == target.ndim + 1:
|
||||
pred = pred.argmax(axis=-1)
|
||||
if seq_len is None and len(target.ndim) > 1:
|
||||
if seq_len is None and target.ndim > 1:
|
||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
|
||||
else:
|
||||
raise RuntimeError(f"when pred have "
|
||||
f"size:{pred.ndim}, target should have size: {pred.ndim} or "
|
||||
f"{pred.ndim[:-1]}, got {target.ndim}.")
|
||||
f"size:{pred.shape}, target should have size: {pred.shape} or "
|
||||
f"{pred.shape[:-1]}, got {target.shape}.")
|
||||
if masks is not None:
|
||||
target = target * masks
|
||||
pred = pred * masks
|
||||
@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric):
|
||||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
|
||||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
|
||||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item()
|
||||
|
||||
|
||||
|
10
fastNLP/core/utils/exceptions.py
Normal file
10
fastNLP/core/utils/exceptions.py
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
class EarlyStopException(BaseException):
|
||||
r"""
|
||||
用于EarlyStop时从Trainer训练循环中跳出。
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, msg):
|
||||
super(EarlyStopException, self).__init__(msg)
|
||||
self.msg = msg
|
@ -12,32 +12,27 @@ def test_get_monitor_value():
|
||||
with Capturing() as output:
|
||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
|
||||
assert monitor == 'f1' and value==0.2
|
||||
assert 'We can not find' not in output[0]
|
||||
|
||||
# 测试可以匹配,且选择更靠前的
|
||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
|
||||
with Capturing() as output:
|
||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
|
||||
assert monitor=='acc#f1' and value==0.2
|
||||
assert 'We can not find' in output[0]
|
||||
|
||||
# 测试monitor匹配不上,使用real_monitor
|
||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
|
||||
with Capturing() as output:
|
||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res)
|
||||
monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res)
|
||||
assert monitor=='acc#rec' and value==0.3
|
||||
assert 'We can not find' not in output[0]
|
||||
|
||||
# 测试monitor/real_monitor匹配不上, 重新选择
|
||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
|
||||
with Capturing() as output:
|
||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res)
|
||||
assert monitor=='acc#f1' and value==0.2
|
||||
assert 'We can not find' in output[0]
|
||||
|
||||
# 测试partial的位置
|
||||
res = {"acc#acc": 0.52, "loss#loss": 2}
|
||||
with Capturing() as output:
|
||||
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res)
|
||||
assert monitor=='loss#loss' and value==2
|
||||
assert 'We can not find' in output[0]
|
||||
|
@ -15,6 +15,7 @@ from sklearn.metrics import accuracy_score as sklearn_accuracy
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.metrics.accuracy import Accuracy
|
||||
from fastNLP.core.metrics.metric import Metric
|
||||
from .utils import find_free_network_port, setup_ddp, _assert_allclose
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
@ -23,42 +24,6 @@ NUM_PROCESSES = 2
|
||||
pool = None
|
||||
|
||||
|
||||
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
|
||||
"""Setup ddp environment."""
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
print(torch.cuda.device_count())
|
||||
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
|
||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real master node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
|
||||
atol: float = 1e-8) -> None:
|
||||
"""
|
||||
测试对比结果,这里不用非得是必须数组且维度对应,一些其他情况例如 np.allclose(np.array([[1e10, ], ]), 1e10+1) 也是 True
|
||||
:param my_result: 可以不限设备等
|
||||
:param sklearn_result:
|
||||
:param atol:
|
||||
:return:
|
||||
"""
|
||||
assert np.allclose(a=my_result, b=sklearn_result, atol=atol)
|
||||
|
||||
|
||||
def _test(local_rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
|
177
tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
Normal file
177
tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
Normal file
@ -0,0 +1,177 @@
|
||||
from functools import partial
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.multiprocessing import Pool, set_start_method
|
||||
|
||||
from fastNLP.core.metrics import ClassifyFPreRecMetric
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from .utils import find_free_network_port, setup_ddp
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _test(local_rank: int, world_size: int, device: torch.device,
|
||||
dataset: DataSet, metric_class, metric_kwargs, metric_result):
|
||||
metric = metric_class(**metric_kwargs)
|
||||
# dataset 也类似(每个进程有自己的一个)
|
||||
dataset = copy.deepcopy(dataset)
|
||||
metric.to(device)
|
||||
# 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上)
|
||||
for i in range(local_rank, len(dataset), world_size):
|
||||
pred, tg = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device)
|
||||
metric.update(pred, tg)
|
||||
|
||||
my_result = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
|
||||
|
||||
class TestClassfiyFPreRecMetric:
|
||||
def test_case_1(self):
|
||||
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],
|
||||
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029],
|
||||
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847],
|
||||
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051],
|
||||
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683],
|
||||
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917],
|
||||
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892],
|
||||
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907],
|
||||
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258],
|
||||
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323],
|
||||
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561],
|
||||
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661],
|
||||
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793],
|
||||
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184],
|
||||
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058],
|
||||
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882],
|
||||
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947],
|
||||
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614],
|
||||
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566],
|
||||
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949],
|
||||
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900],
|
||||
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300],
|
||||
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046],
|
||||
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623],
|
||||
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002],
|
||||
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247],
|
||||
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629],
|
||||
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297],
|
||||
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913],
|
||||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
|
||||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
|
||||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
|
||||
arg_max_pred = torch.argmax(pred, dim=-1)
|
||||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
|
||||
0, 3, 0, 0, 0, 1, 3, 1])
|
||||
|
||||
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
|
||||
metric.update(pred, target)
|
||||
result_dict = metric.get_metric()
|
||||
f1_score = 0.1882051282051282
|
||||
recall = 0.1619047619047619
|
||||
pre = 0.23928571428571427
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
|
||||
metric.update(pred, target)
|
||||
result_dict = metric.get_metric()
|
||||
f1_score = 0.21875
|
||||
recall = 0.21875
|
||||
pre = 0.21875
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
|
||||
metric.update(pred, target)
|
||||
result_dict = metric.get_metric()
|
||||
ground_truth = {
|
||||
'0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7},
|
||||
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
|
||||
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
|
||||
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
|
||||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
|
||||
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
|
||||
'recall': 0.1619047619047619, 'support': 32},
|
||||
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
|
||||
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
|
||||
'support': 32}}
|
||||
for keys in result_dict.keys():
|
||||
if keys == "f" or "pre" or "rec":
|
||||
continue
|
||||
gl = str(keys[-1])
|
||||
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
|
||||
gk = tmp_d[keys[0]]
|
||||
np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
|
||||
|
||||
@pytest.mark.parametrize("f_type, f1_score,recall,pre",
|
||||
[('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
|
||||
('micro', 0.21875, 0.21875, 0.21875)])
|
||||
def test_case_2(self, f_type, f1_score, recall, pre):
|
||||
dataset = DataSet({
|
||||
'pred': [torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],
|
||||
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029],
|
||||
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847],
|
||||
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051],
|
||||
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683],
|
||||
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917],
|
||||
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892],
|
||||
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907],
|
||||
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258],
|
||||
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323],
|
||||
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561],
|
||||
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661],
|
||||
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793],
|
||||
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184],
|
||||
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058],
|
||||
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882]]),
|
||||
torch.tensor([
|
||||
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947],
|
||||
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614],
|
||||
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566],
|
||||
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949],
|
||||
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900],
|
||||
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300],
|
||||
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046],
|
||||
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623],
|
||||
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002],
|
||||
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247],
|
||||
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629],
|
||||
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297],
|
||||
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913],
|
||||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
|
||||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
|
||||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]
|
||||
])],
|
||||
'tg': [
|
||||
torch.LongTensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4]),
|
||||
torch.LongTensor([0, 2, 4, 4, 3, 4, 4, 3, 0, 3, 0, 0, 0, 1, 3, 1])
|
||||
]
|
||||
})
|
||||
metric_kwargs = {
|
||||
'f_type': f_type,
|
||||
'num_class': 5,
|
||||
'only_gross': False,
|
||||
'aggregate_when_get_metric': True
|
||||
}
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
|
||||
NUM_PROCESSES = 2
|
||||
pool = Pool(processes=NUM_PROCESSES)
|
||||
master_port = find_free_network_port()
|
||||
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
|
||||
pool.starmap(partial(_test, dataset=dataset,
|
||||
metric_class=ClassifyFPreRecMetric,
|
||||
metric_kwargs=metric_kwargs,
|
||||
metric_result=ground_truth),
|
||||
[(rank, NUM_PROCESSES, torch.device(f'cuda:{rank+4}')) for rank in range(NUM_PROCESSES)])
|
||||
pool.close()
|
||||
pool.join()
|
@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.metrics import SpanFPreRecMetric
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from .utils import find_free_network_port, setup_ddp
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
@ -41,40 +42,6 @@ NUM_PROCESSES = 2
|
||||
pool = None
|
||||
|
||||
|
||||
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
|
||||
"""Setup ddp environment."""
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
|
||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real master node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
# @pytest.fixture(scope='class', autouse=True)
|
||||
# def pre_process():
|
||||
# global pool
|
||||
# pool = Pool(processes=NUM_PROCESSES)
|
||||
# master_port = find_free_network_port()
|
||||
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
# yield
|
||||
# pool.close()
|
||||
# pool.join()
|
||||
|
||||
|
||||
def _test(local_rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
|
42
tests/core/metrics/utils.py
Normal file
42
tests/core/metrics/utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
import os, sys
|
||||
import socket
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import distributed
|
||||
import numpy as np
|
||||
|
||||
|
||||
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
|
||||
"""Setup ddp environment."""
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
|
||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real master node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
|
||||
atol: float = 1e-8) -> None:
|
||||
"""
|
||||
测试对比结果,这里不用非得是必须数组且维度对应,一些其他情况例如 np.allclose(np.array([[1e10, ], ]), 1e10+1) 也是 True
|
||||
:param my_result: 可以不限设备等
|
||||
:param sklearn_result:
|
||||
:param atol:
|
||||
:return:
|
||||
"""
|
||||
assert np.allclose(a=my_result, b=sklearn_result, atol=atol)
|
Loading…
Reference in New Issue
Block a user