Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
YWMditto 2022-04-12 17:00:18 +08:00
commit 1443065cbd
15 changed files with 522 additions and 217 deletions

View File

@ -10,7 +10,8 @@ __all__ = [
'ProgressCallback',
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback'
'LoadBestModelCallback',
"EarlyStopCallback"
]
@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback

View File

@ -1,11 +1,15 @@
from typing import Union, Callable, Dict, Optional
from typing import Union, Callable, Dict, Optional, Any
from abc import ABC
__all__ = [
'Callback',
]
from .callback_events import Events, EventsList, Filter
from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
class Callback:
@ -150,4 +154,82 @@ class _CallbackWrapper(Callback):
return self.fn.__name__
class CanItemDataType(ABC):
"""
Detect objects that can be transferred, i.e. those exposing an ``item()`` method.
"""
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
self.set_monitor(monitor, larger_better)
self.must_have_monitor = must_have_monitor
def set_monitor(self, monitor, larger_better):
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor
def on_after_trainer_initialized(self, trainer, driver):
"""
If this callback's own monitor is not set, take the monitor configured on the Trainer.
For callbacks that must have a monitor, this function also performs that check.
:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")
def get_monitor_value(self, results:Dict)->float:
"""
Get the value of the monitor. If the monitor key is not found directly, fuzzy matching is attempted and the matched key is stored in the self._real_monitor attribute.
:param results:
:return:
"""
if len(results)==0:
return 0
# make sure every tensor has been converted to a native Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if self._real_monitor != use_monitor:  # a substitution happened, so print a warning
logger.warning(
f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
self._real_monitor = use_monitor
return monitor_value
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
Check whether monitor_value is better than the current best value.
:param monitor_value:
:param keep_if_better: if the given monitor_value is better, keep it as the new best.
:return:
"""
better = False
if (self.larger_better and monitor_value > self.monitor_value) or \
(not self.larger_better and monitor_value < self.monitor_value):
better = True
if keep_if_better:
self.monitor_value = monitor_value
return better
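To illustrate the pattern this base class enables, here is a minimal hypothetical subclass (the class name and the logging behaviour are illustrative only and not part of this commit):

class LogBestCallback(HasMonitorCallback):
    """Hypothetical example: log whenever the monitored metric improves."""

    def __init__(self, monitor=None, larger_better=True):
        # must_have_monitor=True makes on_after_trainer_initialized raise a
        # RuntimeError when neither this callback nor the Trainer sets a monitor.
        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)

    def on_validate_end(self, trainer, results):
        if len(results) == 0:
            return
        value = self.get_monitor_value(results)  # fuzzy-matches the key if necessary
        if self.is_better_monitor_value(value, keep_if_better=True):
            logger.info(f"New best {self._real_monitor}: {value}")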

View File

@ -5,12 +5,12 @@ __all__ = [
import os
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
from pathlib import Path
from abc import ABC
import sys
from copy import deepcopy
import fastNLP
from .callback import Callback, Filter
from .callback import Callback, HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
from fastNLP.core.utils import apply_to_collection
class CanItemDataType(ABC):
"""
Detect objects that can be transferred, i.e. those exposing an ``item()`` method.
"""
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class CheckpointCallback(Callback):
class CheckpointCallback(HasMonitorCallback):
def __init__(
self,
monitor,
@ -48,13 +33,8 @@ class CheckpointCallback(Callback):
model_save_fn: Optional[Callable] = None,
**kwargs,
):
# New logic: if the checkpoint callback itself does not set monitor and larger_better, the values set on the trainer are assigned to them;
# if monitor is None and save_topk is not None:
# raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
if monitor is not None and not isinstance(monitor, str):
raise ValueError("Parameter `monitor` should be of 'str' type.")
super().__init__(monitor=monitor, larger_better=larger_better,
must_have_monitor=save_topk is not None)
if save_folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
@ -92,13 +72,12 @@ class CheckpointCallback(Callback):
"`BaseException` type.")
else:
save_on_exception = []
self.monitor = monitor
self.save_folder = Path(save_folder)
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
self.save_topk = save_topk
self.larger_better = larger_better
self.only_state_dict = only_state_dict
self.model_save_fn = model_save_fn
self.save_on_exception = save_on_exception
@ -108,12 +87,6 @@ class CheckpointCallback(Callback):
self._topk_model = {}
self._topn = 0  # number of best models saved so far;
# In `_get_validate_metric`, when `monitor` cannot be found in the returned `validate_res` dict, the value of a fuzzily matched
# key is used as the result. The problem is that the sub_metric names returned by the user's metric may be ambiguous, and if a
# later training run changes the order in which these sub_metrics are returned, fuzzy matching may pick a different key and value
# than before, which is clearly not reasonable; this variable therefore records the key obtained through fuzzy matching
self._real_monitor = self.monitor
# Note that only process 0 should perform this operation, because when processes are launched with python -m torch.distributed.launch,
# the value of FASTNLP_LAUNCH_TIME differs between processes;
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
@ -121,20 +94,15 @@ class CheckpointCallback(Callback):
synchronize_mkdir(self.timestamp_path)
def on_after_trainer_initialized(self, trainer, driver):
if self.monitor is None:
if trainer.monitor is not None:
self.monitor = trainer.monitor
self.larger_better = trainer.larger_better
elif self.save_topk is not None:
raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this"
"callback or in trainer.")
else:
self.monitor = None
if self.save_topk is not None:
super().on_after_trainer_initialized(trainer, driver)
if self.save_topk is not None and trainer.evaluator is None:
raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.")
logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.")
def on_validate_end(self, trainer, validate_res):
self._save_topk(trainer, validate_res)
def on_validate_end(self, trainer, results):
if len(results) == 0:
return
self._save_topk(trainer, results)
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
@ -157,7 +125,7 @@ class CheckpointCallback(Callback):
def on_sanity_check_end(self, trainer, sanity_check_res):
# mainly to check whether the monitor exists.
self._get_validate_metric(sanity_check_res)
self.get_monitor_value(results=sanity_check_res)
def on_save_checkpoint(self, trainer) -> Dict:
"""
@ -168,8 +136,7 @@ class CheckpointCallback(Callback):
states = {}
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
function=lambda x:x.item())
states['_topk_model'] = deepcopy(self._topk_model)
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
states['_real_monitor'] = self._real_monitor
return states
@ -190,30 +157,30 @@ class CheckpointCallback(Callback):
self._topk_model.update(states["_topk_model"])
self._real_monitor = states["_real_monitor"]
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict):
"""
Decide, based on the validation results, which models to save; folders that no longer rank within the topk are removed automatically.
:param trainer:
:param validate_res:
:param results:
:return:
"""
if self.save_topk is not None:
_metric_value = self._get_validate_metric(validate_res)
monitor_value = self.get_monitor_value(results=results)
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self._real_monitor}_{_metric_value}"
f"-{self._real_monitor}_{monitor_value}"
_should_save = False
if self._topn < self.save_topk:
self._topk_model[folder_name] = _metric_value
self._topk_model[folder_name] = monitor_value
self._topn += 1
_should_save = True
else:
_least_valuable_model = (min if self.larger_better else max)(self._topk_model,
key=lambda x: self._topk_model[x])
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
self._topk_model[folder_name] = _metric_value
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
self._topk_model[folder_name] = monitor_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))
@ -249,7 +216,11 @@ class CheckpointCallback(Callback):
:return:
"""
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
if self._real_monitor != use_monitor:
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
self._real_monitor = use_monitor
return value
@property
@ -277,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback):
When model_save_fn is not None, fastNLP passes the absolute path of the folder to that function and creates no files under the folder itself.
:param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the longest-common-substring match is used to pick the closest one
as the monitor.
as the monitor. If None, the value is taken from the Trainer.
:param save_folder: folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs are saved into different
timestamped folders. If None, the current folder is used by default.
:param save_every_n_epochs: save once every this many epochs.
@ -324,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback):
When model_save_fn is not None, fastNLP only generates a fastnlp_trainer.pkl.tar file under each folder.
:param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the longest-common-substring match is used to pick the closest one
as the monitor.
as the monitor. If None, the value is taken from the Trainer.
:param save_folder: folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs are saved into different
timestamped folders. If None, the current folder is used by default.
:param save_every_n_epochs: save once every this many epochs.
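As a hedged usage sketch of the topk behaviour described above (the argument values are placeholders; the keyword names follow the signature shown in this diff):

from fastNLP.core.callbacks import ModelCheckpointCallback

# Keep the two best checkpoints ranked by the (possibly fuzzy-matched) monitor.
# monitor may also be left as None so that trainer.monitor is inherited; because
# save_topk is set, a monitor and an evaluator must be available, otherwise a
# RuntimeError is raised during on_after_trainer_initialized.
checkpoint_cb = ModelCheckpointCallback(
    monitor='acc',
    larger_better=True,
    save_folder='checkpoints',
    save_topk=2,
)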

View File

@ -0,0 +1,61 @@
__all__ = [
'EarlyStopCallback'
]
from typing import Dict
from .callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException
class EarlyStopCallback(HasMonitorCallback):
def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
"""
:param str monitor: metric to monitor. If None, the monitor set on the Trainer is used.
:param larger_better: whether a larger monitor value is better.
:param patience: stop after this many validations without improvement.
"""
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
self.wait = 0
self.patience = patience
def on_validate_end(self, trainer, results):
if len(results)==0:
return
monitor_value = self.get_monitor_value(results)
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
self.wait = 0
else:
self.wait += 1
def on_fetch_data_begin(self, trainer):
# For step-level validation, this hook is the next to run, so the check happens here.
if self.wait >= self.patience:
raise EarlyStopException(f"After {self.wait} validations, no improvement for "
f"metric `{self._real_monitor}`")
def on_train_epoch_begin(self, trainer):
# For epoch-level validation, this hook is the next to run, so the check happens here.
if self.wait >= self.patience:
raise EarlyStopException(f"After {self.wait} validations, no improvement for "
f"metric `{self._real_monitor}`(best value: {self.monitor_value})")
def on_save_checkpoint(self, trainer) -> Dict:
states = {
'patience': self.patience,
'wait': self.wait,
'monitor': self.monitor,
'monitor_value': self.monitor_value
}
return states
def on_load_checkpoint(self, trainer, states):
self.patience = states['patience']
self.wait = states['wait']
self.monitor = states['monitor']
self.monitor_value = float(states['monitor_value'])
def callback_name(self):
return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
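A brief, hedged usage sketch (how the callback is attached to the Trainer is not shown in this file, so only its construction is demonstrated):

from fastNLP.core.callbacks import EarlyStopCallback

# Leaving monitor=None lets the callback inherit trainer.monitor and
# trainer.larger_better in on_after_trainer_initialized; since
# must_have_monitor=True, a RuntimeError is raised if neither is set.
early_stop = EarlyStopCallback(monitor=None, larger_better=True, patience=5)
# Once `patience` validations pass without improvement, the callback raises
# EarlyStopException, which the Trainer training loop catches and logs
# (see the trainer.py hunk further down).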

View File

@ -4,8 +4,7 @@ __all__ = [
import os
from typing import Optional, Callable
from .callback import Callback
from .utils import _get_monitor_value
from .callback import HasMonitorCallback
from io import BytesIO
import shutil
@ -14,15 +13,15 @@ from fastNLP.core.log import logger
from fastNLP.envs import all_rank_call
class LoadBestModelCallback(Callback):
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True,
class LoadBestModelCallback(HasMonitorCallback):
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
model_load_fn:Optional[Callable] = None,
delete_after_train:bool = True):
"""
Save the model that achieves the best monitor value and reload it when training ends. The best model can only be loaded when training finishes normally.
:param str monitor: metric to monitor.
:param str monitor: metric to monitor. If None, the monitor set on the Trainer is used.
:param larger_better: whether a larger metric value is better.
:param save_folder: folder to save to. If empty, the weights are kept in memory; otherwise a copy of the weights is written to file. For multi-node training with a non-empty value, make sure
the path is accessible from every machine. This value must not be empty when model_save_fn is not None.
@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback):
loading the model must be completed inside that function.
:param delete_after_train: whether to delete the saved model after training ends.
"""
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
if model_load_fn is not None:
assert callable(model_load_fn), "`model_load_fn` must be a callable object."
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time."
@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback):
self.real_save_folder = None
self.buffer = BytesIO()
self.monitor = monitor
self.larger_better = larger_better
self.save_folder = save_folder
self.only_state_dict = only_state_dict
self.model_save_fn = model_save_fn
self.model_load_fn = model_load_fn
self.delete_after_after = delete_after_train
self._real_monitor = None
self.monitor_value = float('-inf') if larger_better else float('inf')
def on_after_trainer_initialized(self, trainer, driver):
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback):
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to "
f"save best model when launch using script.")
super().on_after_trainer_initialized(trainer, driver)
def on_sanity_check_end(self, trainer, sanity_check_res):
self.get_monitor_value(sanity_check_res)
def on_validate_end(self, trainer, results):
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if (monitor_value < self.monitor_value and self.larger_better is False) or \
(monitor_value > self.monitor_value and self.larger_better):
self.monitor_value = monitor_value
if len(results)==0:
return
monitor_value = self.get_monitor_value(results)
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn)
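A hedged construction example for the reworked callback (the argument values are placeholders; the keyword names follow the signature shown above):

from fastNLP.core.callbacks import LoadBestModelCallback

# With save_folder=None the best weights are buffered in memory (a BytesIO buffer);
# monitor=None again defers to trainer.monitor, and must_have_monitor=True
# enforces that one of the two is actually provided.
load_best = LoadBestModelCallback(monitor=None, larger_better=True,
                                  only_state_dict=True, save_folder=None)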

View File

@ -8,7 +8,7 @@ __all__ = [
'RichCallback'
]
from .callback import Callback
from .callback import HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger
@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str):
return None
class ProgressCallback(Callback):
class ProgressCallback(HasMonitorCallback):
def on_train_end(self, trainer):
f_rich_progress.stop()
def on_sanity_check_end(self, trainer, sanity_check_res):
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None:
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=sanity_check_res)
self.get_monitor_value(sanity_check_res)
class RichCallback(ProgressCallback):
@ -46,28 +44,22 @@ class RichCallback(ProgressCallback):
:param print_every: update the display every this many batches.
:param loss_round_ndigit: number of digits to keep when displaying the loss.
:param monitor: when the result for this key improves, it is printed in a different colour as a hint.
:param monitor: when the result for this key improves, it is printed in a different colour as a hint. If None, the monitor set on the trainer is used.
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the JSON before printing.
"""
super().__init__()
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
self.print_every = print_every
self.progress_bar = f_rich_progress
self.task2id = {}
self.loss = 0
self.loss_round_ndigit = loss_round_ndigit
self.monitor = monitor
self.larger_better = larger_better
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = monitor
self.format_json = format_json
def on_after_trainer_initialized(self, trainer, driver):
if not self.progress_bar.disable:
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0)
super(RichCallback, self).on_after_trainer_initialized(trainer, driver)
def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
@ -109,16 +101,12 @@ class RichCallback(ProgressCallback):
text_style = ''
characters = '-'
if self.monitor is not None:
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if (self.larger_better and monitor_value > self.monitor_value) or \
(not self.larger_better and monitor_value < self.monitor_value):
monitor_value = self.get_monitor_value(results)
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if abs(self.monitor_value) != float('inf'):
rule_style = 'spring_green3'
text_style = '[bold]'
characters = '+'
self.monitor_value = monitor_value
self.progress_bar.print()
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, "
f"Batch:{trainer.batch_idx_in_epoch}",
@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback):
:param larger_better: whether a larger monitor result is better.
:param format_json: whether to format the JSON before printing.
"""
super().__init__()
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
self.print_every = print_every
self.task2id = {}
self.loss = 0
self.loss_round_ndigit = loss_round_ndigit
self.monitor = monitor
self.larger_better = larger_better
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = monitor
self.set_monitor(monitor, larger_better)
self.format_json = format_json
self.num_signs = 10
@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback):
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
text = ''
if self.monitor is not None:
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if (self.larger_better and monitor_value > self.monitor_value) or \
(not self.larger_better and monitor_value < self.monitor_value):
monitor_value = self.get_monitor_value(results)
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if abs(self.monitor_value) != float('inf'):
text = '+'*self.num_signs + base_text + '+'*self.num_signs
self.monitor_value = monitor_value
if len(text) == 0:
text = '-'*self.num_signs + base_text + '-'*self.num_signs

View File

@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
if monitor in res:
return monitor, res[monitor]
if real_monitor in res:
return real_monitor, res[real_monitor]
pairs = []
for idx, (key, value) in enumerate(res.items()):
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor))
pairs.append((key, value, match.size, idx))
match_size = _match_length(monitor, key)
pairs.append((key, value, match_size, idx))
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True)
key, value, match_size = pairs[0][:3]
if real_monitor is not None and real_monitor in res and real_monitor != key:
# if real_monitor matches monitor more closely than the newly found key, keep using the previous one.
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor))
if match.size > match_size:
return real_monitor, res[real_monitor]
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), "
f"we use the `{key}` as the monitor.")
real_monitor = key
return real_monitor, value
return key, value
def _match_length(a:str, b:str)->int:
"""
Return the length of the longest common match between a and b (the shorter string must be placed first).
:param a:
:param b:
:return:
"""
short = a if len(a) < len(b) else b
long = a if len(a)>=len(b) else b
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long))
return match.size
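The matching behaviour can be illustrated with a short sketch (the cases mirror those exercised in test_get_monitor_value further down):

from fastNLP.core.callbacks.utils import _get_monitor_value

res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f': 0.4}

# No exact key: the key sharing the longest substring with 'f1' wins, with ties
# broken in favour of the key that appears earlier in the dict (a warning is logged).
print(_get_monitor_value(monitor='f1', real_monitor=None, res=res))        # ('acc#f1', 0.2)

# A previously matched real_monitor that is still present in res is reused directly.
print(_get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res))  # ('acc#rec', 0.3)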

View File

@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc
from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException
class Trainer(TrainerEventTrigger):
@ -50,6 +51,8 @@ class Trainer(TrainerEventTrigger):
model_wo_auto_param_call: bool = False,
accumulation_steps: int = 1,
fp16: bool = False,
monitor: str = None,
larger_better: bool = True,
marker: Optional[str] = None,
**kwargs
):
@ -106,6 +109,10 @@ class Trainer(TrainerEventTrigger):
False, the batch is passed straight through to the forward function; note that the same logic also applies to `train_step`, `validate_step` and `test_step`;
:param accumulation_steps: number of gradient-accumulation steps, i.e. the optimizer steps once every this many batches. Defaults to 1;
:param fp16: whether to enable mixed-precision training. Defaults to False;
:param monitor: the default name of the monitor metric when validate_dataloaders is provided. Callbacks that accept a monitor argument but do not
set it at callback initialization will use this value. If no exactly matching name is found in the evaluation results, the longest-common-substring
match is used to pick the closest one as the monitor;
:param larger_better: whether a larger monitor value is better;
:param marker: used to mark a Trainer instance so that, when the user calls `Trainer.on`, the callback can be associated with a specific 'trainer' instance. Defaults to None;
:param kwargs: other parameters that may be needed;
torch_non_blocking: the non_blocking argument passed to the pytorch tensor `to` method;
@ -213,6 +220,8 @@ class Trainer(TrainerEventTrigger):
self.evaluator = None
self.epoch_validate = lambda *args, **kwargs: ...
self.step_validate = lambda *args, **kwargs: ...
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and validate_dataloaders is not None:
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
@ -242,6 +251,7 @@ class Trainer(TrainerEventTrigger):
else:
# validate_every > 0
self._step_validate_filter = Filter(every=validate_every)
self.metrics = metrics
self.validate_every = validate_every
@ -323,6 +333,10 @@ class Trainer(TrainerEventTrigger):
self.driver.barrier()
self.on_train_end()
self.driver.barrier()
except EarlyStopException as e:
logger.info(f"Catch early stop exception: {e.msg}.")
self.on_exception(e)
except KeyboardInterrupt as e:
self.driver.on_exception()
self.on_exception(e)
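Putting the new monitor/larger_better arguments together with the EarlyStopException handler above, a rough sketch might look as follows (every keyword except monitor, larger_better, metrics and validate_dataloaders, as well as the run() call, is assumed from the surrounding Trainer API rather than shown in this diff; the model, dataloader and metric objects are placeholders):

trainer = Trainer(
    model=model,                          # placeholder model
    train_dataloader=train_dl,            # assumed keyword, placeholder loader
    validate_dataloaders=val_dl,          # placeholder loader
    metrics={'acc': acc_metric},          # placeholder metric
    monitor='acc#acc',                    # default monitor inherited by monitor-aware callbacks
    larger_better=True,
    callbacks=[EarlyStopCallback(patience=5),    # assumed keyword
               LoadBestModelCallback()],         # both inherit the monitor above
)
trainer.run()  # an EarlyStopException raised by the callback is caught and logged here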

View File

@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp):
class ClassifyFPreRecMetric(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro', beta=1) -> None:
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = False) -> None:
super(ClassifyFPreRecMetric, self).__init__(backend=backend,
aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
if tag_vocab:
if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric):
self.tag_vocab = tag_vocab
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\
defaultdict(partial(self.register_element, aggregate_method='sum')),\
defaultdict(partial(self.register_element, aggregate_method='sum'))
self._tp = {}
self._fp = {}
self._fn = {}
if tag_vocab:
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._tp:
continue
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
elif num_class > 0:
for word in range(num_class):
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
else:
raise ValueError("Either `tag_vocab` or a positive `num_class` must be provided.")
def get_metric(self) -> dict:
r"""
@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric):
tag_name = self.tag_vocab.to_word(tag)
else:
tag_name = int(tag)
tp = self._tp[tag]
fn = self._fn[tag]
fp = self._fp[tag]
tp = self._tp[tag].get_scalar()
fn = self._fn[tag].get_scalar()
fp = self._fp[tag].get_scalar()
if tp == fn == fp == 0:
continue
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric):
if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(self._tp.values()),
sum(self._fn.values()),
sum(self._fp.values()))
sum(val.get_scalar() for val in self._tp.values()),
sum(val.get_scalar() for val in self._fn.values()),
sum(val.get_scalar() for val in self._fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)
return evaluate_result
def update(self, pred, target, seq_len=None):
r"""
The update function accumulates the evaluation statistics over one batch of predictions.
:param torch.Tensor pred: prediction tensor whose shape can be torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]) or torch.Size([B, max_len, n_classes])
:param torch.Tensor target: ground-truth tensor whose shape can be torch.Size([B,]) or torch.Size([B, max_len])
:param torch.Tensor seq_len: sequence lengths, either None or a tensor of shape torch.Size([B]);
if a mask is also passed in, seq_len is ignored.
"""
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if seq_len is not None:
@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric):
f"pred have element numbers: {len(target.flatten())}")
pass
elif len(pred.ndim) == len(target.ndim) + 1:
elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and len(target.ndim) > 1:
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
else:
raise RuntimeError(f"when pred have "
f"size:{pred.ndim}, target should have size: {pred.ndim} or "
f"{pred.ndim[:-1]}, got {target.ndim}.")
f"size:{pred.shape}, target should have size: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")
if masks is not None:
target = target * masks
pred = pred * masks
@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric):
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item()
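A brief single-process usage sketch of the revised metric using the num_class path (the tensor values are arbitrary placeholders; the call pattern follows the test file added below):

import torch
from fastNLP.core.metrics import ClassifyFPreRecMetric

metric = ClassifyFPreRecMetric(f_type='macro', num_class=3)
pred = torch.tensor([[0.9, 0.05, 0.05],
                     [0.1, 0.80, 0.10],
                     [0.2, 0.30, 0.50]])   # argmax -> classes [0, 1, 2]
target = torch.tensor([0, 1, 1])
metric.update(pred, target)
print(metric.get_metric())                 # {'f': ..., 'pre': ..., 'rec': ...}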

View File

@ -0,0 +1,10 @@
class EarlyStopException(BaseException):
r"""
Raised to break out of the Trainer training loop for early stopping.
"""
def __init__(self, msg):
super(EarlyStopException, self).__init__(msg)
self.msg = msg

View File

@ -12,32 +12,27 @@ def test_get_monitor_value():
with Capturing() as output:
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
assert monitor == 'f1' and value==0.2
assert 'We can not find' not in output[0]
# test that a fuzzy match is found and the earlier key is preferred
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
with Capturing() as output:
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res)
assert monitor=='acc#f1' and value==0.2
assert 'We can not find' in output[0]
# test that real_monitor is used when monitor itself does not match
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
with Capturing() as output:
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res)
monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res)
assert monitor=='acc#rec' and value==0.3
assert 'We can not find' not in output[0]
# test that when neither monitor nor real_monitor matches, a new key is chosen
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4}
with Capturing() as output:
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res)
assert monitor=='acc#f1' and value==0.2
assert 'We can not find' in output[0]
# test the position of a partial match
res = {"acc#acc": 0.52, "loss#loss": 2}
with Capturing() as output:
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res)
assert monitor=='loss#loss' and value==2
assert 'We can not find' in output[0]

View File

@ -15,6 +15,7 @@ from sklearn.metrics import accuracy_score as sklearn_accuracy
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.metrics.metric import Metric
from .utils import find_free_network_port, setup_ddp, _assert_allclose
set_start_method("spawn", force=True)
@ -23,42 +24,6 @@ NUM_PROCESSES = 2
pool = None
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
"""Setup ddp environment."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
print(torch.cuda.device_count())
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
atol: float = 1e-8) -> None:
"""
Compare results. The inputs do not have to be arrays with matching dimensions; other cases, e.g. np.allclose(np.array([[1e10, ], ]), 1e10+1), also evaluate to True.
:param my_result: may be on any device, etc.
:param sklearn_result:
:param atol:
:return:
"""
assert np.allclose(a=my_result, b=sklearn_result, atol=atol)
def _test(local_rank: int,
world_size: int,
device: torch.device,

View File

@ -0,0 +1,177 @@
from functools import partial
import copy
import pytest
import torch
import numpy as np
from torch.multiprocessing import Pool, set_start_method
from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet
from .utils import find_free_network_port, setup_ddp
set_start_method("spawn", force=True)
def _test(local_rank: int, world_size: int, device: torch.device,
dataset: DataSet, metric_class, metric_kwargs, metric_result):
metric = metric_class(**metric_kwargs)
# the dataset is handled similarly (each process gets its own copy)
dataset = copy.deepcopy(dataset)
metric.to(device)
# split the data across the GPUs, somewhat like DistributedSampler, except that the unit here is a whole batch (each i fetches one batch onto its own GPU)
for i in range(local_rank, len(dataset), world_size):
pred, tg = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device)
metric.update(pred, tg)
my_result = metric.get_metric()
for keys in ['f', 'pre', 'rec']:
assert np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
class TestClassfiyFPreRecMetric:
def test_case_1(self):
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029],
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847],
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051],
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683],
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917],
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892],
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907],
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258],
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323],
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561],
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661],
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793],
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184],
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058],
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882],
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947],
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614],
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566],
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949],
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900],
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300],
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046],
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623],
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002],
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247],
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629],
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297],
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913],
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
arg_max_pred = torch.argmax(pred, dim=-1)
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
0, 3, 0, 0, 0, 1, 3, 1])
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.1882051282051282
recall = 0.1619047619047619
pre = 0.23928571428571427
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
for keys in ['f', 'pre', 'rec']:
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.21875
recall = 0.21875
pre = 0.21875
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
for keys in ['f', 'pre', 'rec']:
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
metric.update(pred, target)
result_dict = metric.get_metric()
ground_truth = {
'0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7},
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
'recall': 0.1619047619047619, 'support': 32},
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
'support': 32}}
for keys in result_dict.keys():
if keys == "f" or "pre" or "rec":
continue
gl = str(keys[-1])
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
gk = tmp_d[keys[0]]
assert np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
@pytest.mark.parametrize("f_type, f1_score,recall,pre",
[('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
('micro', 0.21875, 0.21875, 0.21875)])
def test_case_2(self, f_type, f1_score, recall, pre):
dataset = DataSet({
'pred': [torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910],
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029],
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847],
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051],
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683],
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917],
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892],
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907],
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258],
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323],
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561],
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661],
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793],
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184],
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058],
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882]]),
torch.tensor([
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947],
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614],
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566],
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949],
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900],
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300],
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046],
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623],
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002],
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247],
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629],
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297],
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913],
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]
])],
'tg': [
torch.LongTensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4]),
torch.LongTensor([0, 2, 4, 4, 3, 4, 4, 3, 0, 3, 0, 0, 0, 1, 3, 1])
]
})
metric_kwargs = {
'f_type': f_type,
'num_class': 5,
'only_gross': False,
'aggregate_when_get_metric': True
}
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
NUM_PROCESSES = 2
pool = Pool(processes=NUM_PROCESSES)
master_port = find_free_network_port()
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
pool.starmap(partial(_test, dataset=dataset,
metric_class=ClassifyFPreRecMetric,
metric_kwargs=metric_kwargs,
metric_result=ground_truth),
[(rank, NUM_PROCESSES, torch.device(f'cuda:{rank+4}')) for rank in range(NUM_PROCESSES)])
pool.close()
pool.join()

View File

@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet
from .utils import find_free_network_port, setup_ddp
set_start_method("spawn", force=True)
@ -41,40 +42,6 @@ NUM_PROCESSES = 2
pool = None
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
"""Setup ddp environment."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
# @pytest.fixture(scope='class', autouse=True)
# def pre_process():
# global pool
# pool = Pool(processes=NUM_PROCESSES)
# master_port = find_free_network_port()
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
# yield
# pool.close()
# pool.join()
def _test(local_rank: int,
world_size: int,
device: torch.device,

View File

@ -0,0 +1,42 @@
import os, sys
import socket
from typing import Union
import torch
from torch import distributed
import numpy as np
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
"""Setup ddp environment."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
atol: float = 1e-8) -> None:
"""
Compare results. The inputs do not have to be arrays with matching dimensions; other cases, e.g. np.allclose(np.array([[1e10, ], ]), 1e10+1), also evaluate to True.
:param my_result: may be on any device, etc.
:param sklearn_result:
:param atol:
:return:
"""
assert np.allclose(a=my_result, b=sklearn_result, atol=atol)