Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
MorningForest 2022-05-03 22:20:25 +08:00
commit efa3d5451b
73 changed files with 2899 additions and 1391 deletions

View File

@ -1,4 +1,53 @@
__all__ = [
# callbacks
'Callback',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",
'MoreEvaluateCallback',
"TorchWarmupCallback",
"TorchGradClipCallback",
# collators
'Collator',
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",
# controllers
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'Evaluator',
'Trainer',
# dataloaders TODO: mix_dataloader still needs to be finished
# dataset
'DataSet',
'FieldArray',
'Instance',
'ApplyResultException',
# drivers
"TorchSingleDriver",
"TorchDDPDriver",
"PaddleSingleDriver",
@ -7,16 +56,16 @@ __all__ = [
"JittorMPIDriver",
"TorchPaddleDriver",
"paddle_to",
"get_paddle_gpu_str",
"get_paddle_device_id",
"paddle_move_data_to_device",
"torch_paddle_move_data_to_device",
]
# TODO: the imports here should be optimized later; each sub module should first import its own classes and functions, and the outer module should then import directly from the submodules.
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.evaluator import Evaluator
from fastNLP.core.dataloaders.torch_dataloader import *
# log
"logger"
#
]
from .callbacks import *
from .collators import *
from .controllers import *
from .dataloaders import *
from .dataset import *
from .drivers import *
from .log import *
from .utils import *

View File

@ -1,7 +1,6 @@
__all__ = [
'Callback',
'Events',
'EventsList',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
@ -20,7 +19,7 @@ __all__ = [
from .callback import Callback
from .callback_events import EventsList, Events, Filter
from .callback_event import Event, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback

View File

@ -3,10 +3,9 @@ __all__ = [
'Callback',
]
from typing import Union, Callable, Dict, Optional, Any
from typing import Callable, Dict, Optional
from .callback_events import Events, EventsList, Filter
from fastNLP.core.callbacks.callback_events import _SingleEventState
from .callback_event import Event, Filter
class Callback:
@ -14,32 +13,35 @@ class Callback:
Every callback actually used, whether one of fastNLP's built-in callbacks or a user-defined one, should inherit from this base class.
The callback hooks are invoked roughly in the following order:
Trainer.__init__():
on_after_trainer_initialized()
on_after_trainer_initialized(trainer, driver)
Trainer.run():
if num_eval_sanity_batch>0:
on_sanity_check_begin() # if num_eval_sanity_batch is set
on_sanity_check_end()
on_sanity_check_begin(trainer) # if num_eval_sanity_batch is set
on_sanity_check_end(trainer, sanity_check_res)
try:
on_train_begin()
on_train_begin(trainer)
while cur_epoch_idx < n_epochs:
on_train_epoch_begin()
on_train_epoch_begin(trainer)
while batch_idx_in_epoch<=num_batches_per_epoch:
on_fetch_data_begin()
on_fetch_data_end()
on_train_batch_begin()
on_before_backward()
on_after_backward()
on_before_zero_grad() # actual invocation is affected by accumulation_steps
on_after_zero_grad() # actual invocation is affected by accumulation_steps
on_before_optimizers_step() # actual invocation is affected by accumulation_steps
on_after_optimizers_step() # actual invocation is affected by accumulation_steps
on_train_batch_end()
on_train_epoch_end()
on_fetch_data_begin(trainer)
batch = next(dataloader)
on_fetch_data_end(trainer)
on_train_batch_begin(trainer, batch, indices)
on_before_backward(trainer, outputs) # outputs has passed through output_mapping (if one was set); otherwise it is the raw model output.
on_after_backward(trainer)
on_before_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_before_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_train_batch_end(trainer)
on_train_epoch_end(trainer)
except BaseException:
self.on_exception()
self.on_exception(trainer, exception)
finally:
on_train_end()
Other callbacks, e.g. on_evaluate_begin()/on_evaluate_end(), are invoked at specific times as needed.
on_train_end(trainer)
Other callbacks, e.g. on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer), are invoked as needed at specific
points inside Trainer.run() (a usage sketch follows at the end of this file's diff).
"""
def on_after_trainer_initialized(self, trainer, driver):
@ -294,18 +296,14 @@ class _CallbackWrapper(Callback):
For callback functions that users add via the function decorator, this _CallbackWrapper class customizes them; the class
keeps only that single user callback function.
"""
def __init__(self, event: Union[Events, EventsList], fn: Callable):
def __init__(self, event: Event, fn: Callable):
r"""
:param event: the specific callback timing, e.g. 'on_train_begin'; multiple timings are allowed, in which case the type of `event` should be 'EventsList'.
:param event: the specific callback timing, e.g. 'on_train_begin'.
:param fn: the user-defined callback function.
"""
self.fn = fn
if isinstance(event, EventsList):
for each_event in event:
_filter = Filter(each_event.every, each_event.once, each_event.filter_fn)
setattr(self, each_event.value, _filter(fn))
elif isinstance(event, _SingleEventState):
if isinstance(event, Event):
_filter = Filter(event.every, event.once, event.filter_fn)
setattr(self, event.value, _filter(fn))
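A minimal sketch of a user-defined callback built on the hooks above. This is hedged: only the hook names and signatures documented in this diff are assumed, plus the top-level Callback export from the core __all__ earlier in this commit; trainer.cur_epoch_idx is the counter used in the pseudo-flow above.

from fastNLP import Callback

class LossPrinter(Callback):
    def on_before_backward(self, trainer, outputs):
        # outputs is the model output after output_mapping (if one was set)
        print(outputs)

    def on_train_epoch_end(self, trainer):
        print(f"finished epoch {trainer.cur_epoch_idx}")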

View File

@ -0,0 +1,499 @@
from typing import Optional, Callable, Dict
from functools import wraps
__all__ = [
'Event',
'Filter'
]
def check_legality(fn):
@wraps(fn)
def wrap(every=None, once=None, filter_fn=None):
if (every is None) and (once is None) and (filter_fn is None):
every = 1
if sum(arg is not None for arg in (every, once, filter_fn)) != 1:
raise ValueError("Only one of `every`, `once` and `filter_fn` may be set.")
if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument filter_fn should be a callable")
if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")
if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")
return fn(every=every, once=once, filter_fn=filter_fn)
return wrap
class Event:
every: Optional[int]
once: Optional[int]
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
filter_fn: Optional[Callable] = None):
"""
Do not construct this object directly; obtain it by calling Event.on_after_trainer_initialized() and the like.
:param value: the Trainer callback timing.
:param int every: actually run once for every `every` triggers.
:param int once: run only at the `once`-th trigger, never afterwards.
:param Callable filter_fn: a callable taking (filter, trainer); the filter object carries filter.num_called and
filter.num_executed, which give how many times the hook has been triggered and how many times it has actually run; trainer is the currently running Trainer.
"""
self.every = every
self.once = once
self.filter_fn = filter_fn
self.value = value
def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once,
self.filter_fn)
@staticmethod
@check_legality
def on_after_trainer_initialized(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_after_trainer_initialized
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1 默认为
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_sanity_check_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_sanity_check_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_sanity_check_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_sanity_check_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_epoch_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_epoch_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_epoch_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_epoch_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_fetch_data_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_fetch_data_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_fetch_data_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_fetch_data_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_batch_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_batch_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_train_batch_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_train_batch_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_exception(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_exception
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_save_model(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_save_model
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_load_model(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_load_model
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_save_checkpoint(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_save_checkpoint
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_load_checkpoint(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_load_checkpoint
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_before_backward(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_before_backward
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_after_backward(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_after_backward
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_before_optimizers_step(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_before_optimizers_step
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_after_optimizers_step(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_after_optimizers_step
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_before_zero_grad(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_before_zero_grad
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_after_zero_grad(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_after_zero_grad
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_evaluate_begin(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_evaluate_begin
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn)
@staticmethod
@check_legality
def on_evaluate_end(every=None, once=None, filter_fn=None):
"""
Trainer 运行到 on_evaluate_end
以下三个参数互斥只能设置其中一个默认为行为等同于 every=1
:param int every: 触发了多少次才真正运行一次
:param bool once: 是否只在第一次运行后就不再执行了
:param Callable filter_fn: 输入参数的应该为 (filter, trainer)其中 filter 对象中包含了 filter.num_called
filter.num_executed 两个变了分别获取当前被调用了多少次真正执行了多少次trainer 对象即为当前正在运行的 Trainer
:return:
"""
return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn)
class Filter:
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
r"""
Use this `Filter` as a function decorator to control how often the decorated function actually runs.
:param every: run the function once every `every` calls.
:param once: run the function only once, at the `once`-th call.
:param filter_fn: a user-defined frequency-control function; note that the frequency logic inside it should be stateless
(apart from `self.num_called` and `self.num_executed`, whose state we reset after the sanity-check run).
"""
# check legality
check_legality(lambda *args,**kwargs:...)(every, once, filter_fn)
if (every is None) and (once is None) and (filter_fn is None):
every = 1
# set up the variables, including the global counters;
self.num_called = 0
self.num_executed = 0
if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn
def __call__(self, fn: Callable):
@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1
# our callback functions have a fixed signature, and the first argument is guaranteed to be the trainer
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)
wrapper.__fastNLP_filter__ = self
return wrapper
def every_filter(self, *args):
return self.num_called % self._every == 0
def once_filter(self, *args):
return self.num_called == self._once
def state_dict(self) -> Dict:
r"""
Save the state of this `Filter` via this function.
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}
def load_state_dict(self, state: Dict):
r"""
Load the state of this `Filter` via this function.
:param state: the state dict saved by `Filter.state_dict`.
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]
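A hedged usage sketch of the Event/Filter machinery above. The import path follows the new callback_event.py module in this diff; the _Trainer stand-in is made up, since Filter only requires that the first positional argument be the trainer.

from fastNLP.core.callbacks.callback_event import Event, Filter

class _Trainer:  # stand-in object for illustration
    pass

f = Filter(every=2)

@f
def hook(trainer):
    print("executed")

for _ in range(4):
    hook(_Trainer())                 # prints "executed" on the 2nd and 4th calls
print(f.num_called, f.num_executed)  # 4 2

# Event factories carry the same trigger controls for Trainer.on():
print(Event.on_train_batch_begin(every=8))
# <event=on_train_batch_begin, every=8, once=None, filter fn is:None>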

View File

@ -1,206 +0,0 @@
from enum import Enum, unique
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict
from types import DynamicClassAttribute
from functools import wraps
__all__ = [
'Events',
'EventsList',
'Filter'
]
class _SingleEventState:
every: Optional[int]
once: Optional[int]
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
filter_fn: Optional[Callable] = None, name: Optional[str] = None):
# the actual argument-validation logic lives in the concrete Filter;
if every is None and once is None and filter_fn is None:
self.every = 1
self.once = None
self.filter_fn = None
else:
self.every = every
self.once = once
self.filter_fn = filter_fn
if not hasattr(self, "_value_"):
self._value_ = value
if not hasattr(self, "_name_") and name is not None:
self._name_ = name
# copied to be compatible to enum
@DynamicClassAttribute
def name(self) -> str:
"""The name of the Enum member."""
return self._name_
@DynamicClassAttribute
def value(self) -> str:
"""The value of the Enum member."""
return self._value_
def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
return _SingleEventState(self.value, every, once, filter_fn, self.name)
def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once,
self.filter_fn)
def __eq__(self, other) -> bool:
if isinstance(other, _SingleEventState):
return self.name == other.name
elif isinstance(other, str):
return self.name == other
else:
raise NotImplemented
def __hash__(self):
return hash(self._name_)
def __or__(self, other) -> "EventsList":
return EventsList() | self | other
class EventEnum(_SingleEventState, Enum):
pass
@unique
class Events(EventEnum):
on_after_trainer_initialized = "on_after_trainer_initialized"
on_sanity_check_begin = "on_sanity_check_begin"
on_sanity_check_end = "on_sanity_check_end"
on_train_begin = "on_train_begin"
on_train_end = "on_train_end"
on_train_epoch_begin = "on_train_epoch_begin"
on_train_epoch_end = "on_train_epoch_end"
on_fetch_data_begin = "on_fetch_data_begin"
on_fetch_data_end = "on_fetch_data_end"
on_train_batch_begin = "on_train_batch_begin"
on_train_batch_end = "on_train_batch_end"
on_exception = "on_exception"
on_save_model = "on_save_model"
on_load_model = "on_load_model"
on_save_checkpoint = "on_save_checkpoint"
on_load_checkpoint = "on_load_checkpoint"
on_before_backward = "on_before_backward"
on_after_backward = "on_after_backward"
on_before_optimizers_step = "on_before_optimizers_step"
on_after_optimizers_step = "on_after_optimizers_step"
on_before_zero_grad = "on_before_zero_grad"
on_after_zero_grad = "on_after_zero_grad"
on_evaluate_begin = "on_evaluate_begin"
on_evaluate_end = "on_evaluate_end"
class EventsList:
"""Collection of events stacked by operator `__or__`.
"""
def __init__(self) -> None:
self._events = [] # type: List[Union[Events, _SingleEventState]]
def _append(self, event: Union[Events, _SingleEventState]) -> None:
if not isinstance(event, (Events, _SingleEventState)):
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
self._events.append(event)
def __getitem__(self, item: int) -> Union[Events, _SingleEventState]:
return self._events[item]
def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]:
return iter(self._events)
def __len__(self) -> int:
return len(self._events)
def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList":
self._append(event=other)
return self
class Filter:
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
r"""
Use this `Filter` as a function decorator to control how often the decorated function actually runs.
:param every: run the function once every `every` calls.
:param once: run the function only once, at the `once`-th call.
:param filter_fn: a user-defined frequency-control function; note that the frequency logic inside it should be stateless
(apart from `self.num_called` and `self.num_executed`, whose state we reset after the sanity-check run).
"""
if (every is None) and (once is None) and (filter_fn is None):
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.")
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
raise ValueError("These three values should be only set one.")
if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument event_filter should be a callable")
if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")
if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")
# set up the variables, including the global counters;
self.num_called = 0
self.num_executed = 0
if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn
def __call__(self, fn: Callable):
@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1
# our callback functions have a fixed signature, and the first argument is guaranteed to be the trainer
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)
wrapper.__fastNLP_filter__ = self
return wrapper
def every_filter(self, *args):
return self.num_called % self._every == 0
def once_filter(self, *args):
return self.num_called == self._once
def state_dict(self) -> Dict:
r"""
Save the state of this `Filter` via this function.
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}
def load_state_dict(self, state: Dict):
r"""
Load the state of this `Filter` via this function.
:param state: the state dict saved by `Filter.state_dict`.
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]

View File

@ -6,7 +6,7 @@ __all__ = [
'CallbackManager'
]
from .callback_events import Events
from .callback_event import Event
from .callback import Callback
from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
@ -110,7 +110,7 @@ class CallbackManager:
def initialize_class_callbacks(self):
r"""
At runtime we split a concrete callback instance into its individual callback functions and register them in a dict whose
keys are the callback timings, i.e. the categories of `Events`.
keys are the callback timings, i.e. the categories of `Event`.
If a callback function of a callback class has no effect, we do not actually add it to the dict.
:param callbacks:
@ -127,11 +127,12 @@ class CallbackManager:
:param callback: a concrete callback instance.
"""
self.all_callbacks.append(callback)
for name, member in Events.__members__.items():
_fn = getattr(callback, member.value)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)):
self.callback_fns[member.value].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)
for name, member in Event.__dict__.items():
if isinstance(member, staticmethod):
_fn = getattr(callback, name)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)):
self.callback_fns[name].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)
def extract_callback_filter_state(self, callback_name, callback_fn):
r"""

View File

@ -161,7 +161,6 @@ class MonitorUtility:
return monitor_name
class HasMonitorCallback(MonitorUtility, Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""

View File

@ -1,4 +1,20 @@
__all__ = [
'Collator'
'Collator',
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",
]
from .collator import Collator
from .padders import *

View File

@ -65,12 +65,16 @@ def _get_backend() -> str:
return catch_backend[0]
# approach (2)
for backend in CHECK_BACKEND:
if backend in sys.modules:
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
return backend
for key, module in sys.modules.items():
catch_backend = _check_module(module)
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0]
return 'numpy'
@ -227,7 +231,7 @@ class Collator:
Set what type of tensor the pad-able fields are padded into by default.
:param backend: which tensor type to use for pad-able fields; one of ['torch','jittor','paddle','numpy','raw', 'auto', None].
If 'auto', the backend is decided from the calling environment when padding.
If 'auto', the backend is automatically decided from the calling environment when padding.
:return:
"""
assert backend in SUPPORTED_BACKENDS

View File

@ -0,0 +1,30 @@
__all__ = [
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",
]
from .numpy_padder import *
from .padder import Padder, NullPadder
from .raw_padder import *
from .torch_padder import *
from .paddle_padder import *
from .utils import get_padded_numpy_array

View File

@ -1,8 +1,3 @@
from typing import Dict
from typing import Sequence, Any, Union, Dict
from abc import ABC
@ -12,7 +7,7 @@ from fastNLP.core.log import logger
from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .exceptions import *
@ -28,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
:param field_name: used to produce clearer error messages.
:return:
"""
assert len(batch_field)!=0, "Empty batch encountered."
logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
@ -68,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NullPadder()
# then check whether the types of all elements are consistent
ele_dtypes = set([v[1] for v in catalog.values()])
try:
ele_dtypes = set([v[1] for v in catalog.values()])
except TypeError:
ele_dtypes = set([str(v[1]) for v in catalog.values()])
num_eletypes = len(ele_dtypes)
if num_eletypes != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
@ -80,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
depth = depths.pop()
shape_len = shape_lens.pop()
ele_dtype = ele_dtypes.pop()
ele_dtype = list(catalog.values())[0][1]  # handled this way because of the except branch above
# the padder itself decides whether padding is possible.
try:
@ -93,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
if depth > 1 and shape_len == 0:  # e.g. [[0, 1], [2]]
if backend == 'raw':
@ -103,14 +103,21 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
# if the elements have a shape, this only works when the object provides a tolist() method
if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)):
if backend == 'raw':
return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."
@ -179,23 +186,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
else: # covers int/float/bool/dict and other types that cannot be padded
catalog[parent] = ((), type(content)) # () means the shape has length 0; the second item is the type
return catalog
"""
from numbers import Number
issubclass(type(3), Number) # True
issubclass(type(3.1), Number) # True
issubclass(type('3'), Number) # False
issubclass(type(True), Number) # True
issubclass(type(np.zeros(3)[0]), Number) # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True (distinguishing this case requires a further check)
is_torch_tensor_dtype() # can be checked via isinstance(torch.zeros(3).dtype, torch.dtype)
"""

View File

@ -66,7 +66,7 @@ class NumpySequencePadder(Padder):
class NumpyTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads a field like [np.array([3, 4]), np.array([1])].
Pads a field like [np.array([3, 4]), np.array([1])]; if the inner elements are not np.ndarray, they must provide a tolist() method.
:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of this field can be converted to np.array.
@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder):
@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], np.ndarray):
batch_field = [np.array(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
f"it must have tolist() method.")
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
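A hedged usage sketch of the padder above; the import path is assumed from this diff's file layout, and the printed result follows from the max-shape logic shown:

import numpy as np
from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder

batch = [np.array([3, 4]), np.array([1])]
print(NumpyTensorPadder.pad(batch, pad_val=0, dtype=int))
# [[3 4]
#  [1 0]]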

View File

@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype):
def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")
@ -74,13 +74,20 @@ def _get_dtype(ele_dtype, dtype, class_name):
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
else:
dtype == ele_dtype
dtype = ele_dtype
return dtype
class PaddleNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Converts data like [1, 2, 3] into paddle.Tensor([1, 2, 3]).
:param pad_val: this value has no effect.
:param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
:param dtype: the dtype of the output data, e.g. int, float, 'int32'.
"""
# only when ele_dtype is a python number / numpy number or a tensor
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder):
class PaddleSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, ele_dtype=None, pad_val=0, dtype=None):
"""
Pads content like [[1], [1, 2]] into paddle.Tensor([[1, 0], [1, 2]]); supports padding multiply nested data.
:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
:param dtype: the dtype of the output data, e.g. int, float, 'int32'.
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder):
class PaddleTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Currently supports input like [paddle.tensor([3, 2]), paddle.tensor([1])].
Currently supports input like [paddle.tensor([3, 2]), paddle.tensor([2, 1])]; if the inner elements are not paddle.tensor, they must provide a tolist() method.
:param ele_dtype:
:param pad_val:
:param dtype:
:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
:param dtype: the dtype of the output data, e.g. int, float, 'int32'.
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], paddle.Tensor):
batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
@ -174,6 +195,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
"""
shapes = get_shape(batch_field)
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
# tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor

View File

@ -1,4 +1,8 @@
__all__ = [
"RawNumberPadder",
"RawSequencePadder",
"RawTensorPadder"
]
from .padder import Padder
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number
@ -63,3 +67,34 @@ class RawSequencePadder(Padder):
:return:
"""
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
class RawTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads content like [[1], [1, 2]] into [[1, 0], [1, 2]]; supports padding multiply nested data.
:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of this field can be converted to np.array.
:param dtype: the dtype of the output data.
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@staticmethod
def pad(batch_field, pad_val, dtype):
"""
:param batch_field:
:param pad_val:
:param dtype: this parameter has no effect.
:return:
"""
try:
if not isinstance(batch_field[0], (list, tuple)):
batch_field = [field.tolist() for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), "
f"it must have tolist() method.")
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()

View File

@ -1,4 +1,8 @@
__all__ = [
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder'
]
from inspect import isclass
import numpy as np
@ -37,7 +41,7 @@ def is_torch_tensor(dtype):
def _get_dtype(ele_dtype, dtype, class_name):
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")
@ -97,7 +101,7 @@ class TorchSequencePadder(Padder):
class TorchTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Currently supports input like [torch.tensor([3, 2]), torch.tensor([1])].
Currently supports input like [torch.tensor([3, 2]), torch.tensor([1])]; if the inner elements are not torch.tensor, they must provide a tolist() method.
:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of this field can be converted to torch.tensor.
@ -108,6 +112,13 @@ class TorchTensorPadder(Padder):
@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], torch.Tensor):
batch_field = [torch.tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
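An analogous hedged sketch for the torch padder (import path assumed from this diff's file layout):

import torch
from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder

batch = [torch.tensor([3, 2]), torch.tensor([1])]
print(TorchTensorPadder.pad(batch, pad_val=0, dtype=torch.long))
# tensor([[3, 2],
#         [1, 0]])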

View File

@ -1,6 +1,10 @@
__all__ = [
'get_padded_numpy_array'
]
from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass

View File

@ -2,8 +2,6 @@ __all__ = [
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'State',
'TrainerState',
'Evaluator',
'Trainer',
]

View File

@ -17,10 +17,10 @@ from .utils import State, TrainerState
from .utils.utils import check_evaluate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks import Callback, CallbackManager
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_manager import prepare_callbacks
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.callback_event import Event
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
raise e
finally:
self.on_train_end()
self.driver.barrier()
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
@ -399,7 +398,7 @@ class Trainer(TrainerEventTrigger):
if self.cur_epoch_idx % evaluate_every == 0:
self.run_evaluate()
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
def add_callback_fn(self, event: Event, fn: Callable):
r"""
After a trainer instance has been initialized, users can use this function to conveniently add callback functions;
this is done on a concrete trainer instance, so no `marker` parameter is needed (see the sketch at the end of this file's diff).
@ -407,19 +406,69 @@ class Trainer(TrainerEventTrigger):
:param event: the specific callback timing; users must specify which callback timing this callback function belongs to.
:param fn: the callback function itself.
"""
if not isinstance(event, (_SingleEventState, EventsList)):
raise ValueError("parameter event should only be `Events` or `EventsList` type.")
if not isinstance(event, Event):
raise ValueError("parameter event should only be `Event` type.")
_custom_callback = _CallbackWrapper(event, fn)
self.callback_manager.dissect_one_callback(_custom_callback)
@classmethod
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None):
def on(cls, event: Event, marker: Optional[str] = None):
r"""
A function decorator through which users can conveniently turn a function into a callback function and thus control the training process.
The supported event timings, the order in which they fire, and the parameter list each decorated function should accept are as follows, e.g.:
Trainer.__init__():
on_after_trainer_initialized(trainer, driver)
Trainer.run():
if num_eval_sanity_batch>0:
on_sanity_check_begin(trainer) # if num_eval_sanity_batch is set
on_sanity_check_end(trainer, sanity_check_res)
try:
on_train_begin(trainer)
while cur_epoch_idx < n_epochs:
on_train_epoch_begin(trainer)
while batch_idx_in_epoch<=num_batches_per_epoch:
on_fetch_data_begin(trainer)
batch = next(dataloader)
on_fetch_data_end(trainer)
on_train_batch_begin(trainer, batch, indices)
on_before_backward(trainer, outputs) # outputs has passed through output_mapping (if one was set); otherwise it is the raw model output.
on_after_backward(trainer)
on_before_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_before_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_train_batch_end(trainer)
on_train_epoch_end(trainer)
except BaseException:
self.on_exception(trainer, exception)
finally:
on_train_end(trainer)
Other callbacks, e.g. on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer), are invoked as needed at specific
points inside Trainer.run().
Example::
from fastNLP import Event
@Trainer.on(Event.on_save_model())
def do_something_1(trainer):
# do something
# The function above runs whenever the Trainer saves its model.
@Trainer.on(Event.on_save_model(once=True))
def do_something_2(trainer):
# do something
# The function above runs when the Trainer saves its model, but only once.
@Trainer.on(Event.on_train_batch_begin(every=2))
def do_something_3(trainer, batch, indices):
# do something
# The function above runs at the start of every new batch, but only once every two batches.
Note: if you use this decorator to add callbacks to your training, make sure the code that registers the callback runs before the `Trainer` is instantiated.
:param event: the specific callback timing; users must specify which callback timing this callback function belongs to.
:param event: the specific callback timing; users must specify which callback timing this callback function belongs to. The function
run at each timing must take specific parameters, which can be looked up in the description above.
:param marker: marks which concrete trainer instances this callback function belongs to. Two special cases: 1. when `marker` is
None (the default), the callback function belongs only to the nearest trainer instance below it in the code; 2. when `marker`
is 'all', the callback function is used by all trainer instances.
@ -427,9 +476,9 @@ class Trainer(TrainerEventTrigger):
"""
def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
_check_valid_parameters_number(fn, callback_fn_args)
cls._custom_callbacks[marker].append((event, fn))
return fn
return wrapper
@ -441,6 +490,7 @@ class Trainer(TrainerEventTrigger):
"""
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
_own_callbacks.extend(self._custom_callbacks[None])
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
self._custom_callbacks[None] = []
if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0:
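A hedged sketch of registering a callback function after the trainer exists, via the add_callback_fn method shown earlier in this file (the wrapper function is made up for illustration; Event is imported as in the Example above):

from fastNLP import Event

def attach_printer(trainer):
    def print_outputs(tr, outputs):
        print(outputs)
    # run every 100 batches, just before the backward pass
    trainer.add_callback_fn(Event.on_before_backward(every=100), print_outputs)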

View File

@ -14,7 +14,7 @@ else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
@ -107,33 +107,33 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "JittorDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of some field needs special adjustment.
Use this function if the content of a particular field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray, torch.Tensor,
paddle.Tensor or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None. If backend is
None, this value has no effect.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. If pad_val is None, this value has no effect.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "JittorDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, set it here; the fields set here will be ignored in the batch output.
Ex::
@ -146,18 +146,17 @@ class JittorDataLoader:
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current data.
Get the indices of the current batch.
:return:
"""
return self.cur_batch_indices
def prepare_jittor_dataloader():
...

View File

@ -15,8 +15,9 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
from fastNLP.core.collators.collator import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
class _PaddleDataset(Dataset):
@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)
if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle")
# if collate_fn is not None:
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="paddle")
@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
self.cur_batch_indices = indices
yield data
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "PaddleDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of some field needs special adjustment.
Use this function if the content of a particular field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray, torch.Tensor,
paddle.Tensor or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None. If backend is
None, this value has no effect.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. If pad_val is None, this value has no effect.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "PaddleDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, set it here; the fields set here will be ignored in the batch output.
Ex::
@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current data.
Get the indices of the current batch.
:return:
"""
@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader):
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
return_list: bool = True,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
num_workers: int = 0, use_buffer_reader: bool = True,

View File

@ -3,14 +3,14 @@ __all__ = [
'prepare_torch_dataloader'
]
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@ -76,6 +76,10 @@ class TorchDataLoader(DataLoader):
if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset)
if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@ -87,9 +91,6 @@ class TorchDataLoader(DataLoader):
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
# if collate_fn is not None and collate_fn is not default_collate:
# # prevent the torch dataloader's default collate from being re-added when ddp re-initializes
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="torch")
else:
@ -112,31 +113,32 @@ class TorchDataLoader(DataLoader):
yield data
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "TorchDataLoader":
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of some field needs special adjustment.
Use this function if the content of a particular field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray, torch.Tensor,
paddle.Tensor or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for 'b' in
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote its 0th or 1st element. An error is raised
if the field is not found in the data; if __getitem__ returns the whole content as-is, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field is not padded; fastNLP only pads
fields that are in a pad-able form, so a field that is not pad-able need not be explicitly set to None. If backend is
None, this value has no effect.
:param dtype: for fields that need padding, the dtype the field's data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. If pad_val is None, this value has no effect.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype and backend are ignored. pad_fn receives
the batch form of this field (the Collator automatically unbatches the data and groups each field into its own batch;
the input of pad_fn is that batch form of the field), and its output is used directly as the result.
:return: the Collator itself.
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "TorchDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, configure it here; the specified field will be ignored in the batch output.
Ex::
@ -149,24 +151,23 @@ class TorchDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current data.
Get the indices of the current batch.
:return:
"""
return self.cur_batch_indices
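Note the behavioral change above: set_pad and set_ignore now hand back the underlying Collator instead of the dataloader. A short sketch, assuming `dl` is a TorchDataLoader built from a fastNLP DataSet with the field names shown (construction omitted):

    collator = dl.set_pad('words', pad_val=0, backend='torch')  # returns the Collator, not the dataloader
    collator.set_pad('seq_len', pad_val=None)                   # further tweaks chain on the Collator
    dl.set_ignore('raw_words')                                  # ignored fields are dropped from each batch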
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,

View File

@ -0,0 +1,16 @@
def indice_collate_wrapper(func):
"""
Wraps a collate_fn one level: the (idx, instance) tuples fetched from the dataset are split apart, the idx values are packed into indices, and the instances are passed to the wrapped function.
:param func: the function to wrap
:return:
"""
def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
return wrapper
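A minimal illustration of the wrapper's behavior, using a stand-in collate function; the (idx, instance) tuples are assumed to come from a dataset such as _FDataSet:

    def naive_collate(ins_list):
        return ins_list  # stand-in for a real collate_fn

    wrapped = indice_collate_wrapper(naive_collate)
    indices, batch = wrapped([(0, {'x': 1}), (1, {'x': 2})])
    assert indices == [0, 1] and batch == [{'x': 1}, {'x': 2}]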

View File

@ -770,17 +770,8 @@ class DataSet:
df = self.to_pandas()
return df.to_csv(path, encoding="utf-8")
def set_ignore(self, *field_names) -> None:
"""
field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default.
:param field_names:
:return:
"""
self.collator.set_ignore(*field_names)
@property
def collator(self):
def collator(self) -> Collator:
if self._collator is None:
self._collator = Collator()
return self._collator

View File

@ -22,7 +22,7 @@ from fastNLP.core.utils import (
rank_zero_rm
)
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
ReproducibleBatchSampler,
RandomSampler,
@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):
return self.model, model.forward
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
reproducible: bool = False):
r"""
Based on the input dataloader, return a dataloader that supports distributed (distributed) and reproducible (reproducible) iteration.

View File

@ -22,7 +22,7 @@ from fastNLP.core.log import logger
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
RandomBatchSampler,
ReproduceBatchSampler,
RandomSampler,
)
@ -345,7 +345,7 @@ class PaddleDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@ -476,7 +476,7 @@ class PaddleDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler

View File

@ -14,7 +14,7 @@ from fastNLP.core.utils import (
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last

View File

@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger
@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last

View File

@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
class TorchDriver(Driver):
@ -293,7 +293,7 @@ class TorchDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@ -407,7 +407,7 @@ class TorchDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler

fastNLP/core/log/print.py Normal file
View File

@ -0,0 +1,25 @@
__all__ = [
'print'
]
from .logger import logger
def print(*args, sep=' ', end='\n', file=None, flush=False):
"""
A function that redirects print() to logger.info.
Example:
from fastNLP import print
print("This is a test") # equivalent to calling logger.info("This is a test")
:param args: the content to print
:param sep: the separator used when there are multiple inputs
:param end: meaningless under the current setup, since a trailing \n is always appended
:param file: meaningless here
:param flush: meaningless here
:return:
"""
line = sep.join(map(str, args))  # str() so non-string objects can be printed, as with the builtin print
logger.info(line)

View File

@ -14,9 +14,10 @@ __all__ = [
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",
"RandomBatchSampler",
"ReproduceBatchSampler",
"BucketedBatchSampler",
"ReproducibleBatchSampler",
"RandomBatchSampler",
"re_instantiate_sampler"
]
@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler

View File

@ -1,5 +1,6 @@
__all__ = [
'BucketedBatchSampler',
"ReproduceBatchSampler",
"RandomBatchSampler"
]
@ -7,7 +8,6 @@ import math
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os
import numpy as np
@ -54,13 +54,12 @@ class ReproducibleBatchSampler:
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
class RandomBatchSampler(ReproducibleBatchSampler):
# the values of these two parameters should be fetched via the driver's get_dataloader_args function;
class ReproduceBatchSampler(ReproducibleBatchSampler):
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
A wrapper that makes the state of a batch_sampler object recoverable.
:param batch_sampler: an iterable that yields numbers or lists of numbers. RandomBatchSampler first iterates over
:param batch_sampler: an iterable that yields numbers or lists of numbers. ReproduceBatchSampler first iterates over
the object once, caches the yielded indices, and when used emits lists of indices in batches of batch_size.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the remaining samples if the last batch cannot reach batch_size.
@ -143,7 +142,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False
def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@ -158,6 +157,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
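A usage sketch of the renamed wrapper, assuming torch is available; it wraps an existing batch_sampler so that the iteration position can be captured via the state_dict()/load_state_dict() pair used throughout this diff:

    from torch.utils.data import BatchSampler, SequentialSampler

    inner = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
    sampler = ReproduceBatchSampler(batch_sampler=inner, batch_size=4, drop_last=False)
    for batch in sampler:
        state = sampler.state_dict()  # capture progress after the first batch
        break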
class RandomBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
drop_last: bool = False, seed: int = 0, **kwargs):
"""
随机分 batch batch_sampler
:param dataset: 实现了 __len__ 方法的数据容器
:param batch_size: 每个 batch 的大小
:param shuffle: 如果为 True将不进行 shuffle实际上数据会以从长到短的方式输出
:param drop_last: 如果最后一个 batch sample 数量无法凑齐 batch_size 这么多是否需要丢掉
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.seed = seed
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been iterated in total, including those yielded on other cards in the multi-card case
# multi-card related parameters
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card;
# whether we are in the middle of an iteration; when True, set_distributed() and load_state_dict() must not be called
self.during_iter = kwargs.get("during_iter", False)
# the following variables are internal and used to restore state.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
def set_distributed(self, num_replicas, rank, pad=True):
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
"during an unfinished iteration."
assert num_replicas > 0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0 <= rank < num_replicas
# note: when this function is called, all state should default to that of an epoch that has just started;
self.num_replicas = num_replicas
self.rank = rank
self.pad = pad
return self
def __iter__(self):
if self.during_iter:  # if during_iter is True, the previous iteration never finished; force a re-initialization
self.num_consumed_samples = 0
self.during_iter = True
indices = list(range(len(self.dataset)))
if self.shuffle:
if self.num_consumed_samples > 0:  # first reproduce the original ordering, then drop the already-consumed part
_batches = []
for _i in range(self.old_num_replicas):
_indices = indices[_i:len(indices):self.old_num_replicas]
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
_batches.append(__batches)
batches = list(chain(*[_ for _ in zip(*_batches)]))
indices = list(chain(*batches))
indices = indices[self.num_consumed_samples:]
# take out this rank's share
indices = indices[self.rank:len(indices):self.num_replicas]
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
batches = list(map(list, batches))
else:
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
_num_batches = len(indices) // self.batch_size
if _num_batches == 0:
batches = [indices]
else:
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
if len(indices)%self.batch_size!=0:
batches.append(indices[_num_batches*self.batch_size:])
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
batches[-1].append(batches[-1][0])  # this guarantees the length of the bucket is not broken.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1])==0:
batches.pop(-1)
assert sum(map(len, batches)) == self.num_left_samples
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
self.old_batch_size = self.batch_size
self.old_num_replicas = self.num_replicas
if self.epoch < 0:  # guard against the user never updating epoch, which would make every epoch identical
self.epoch -= 1
def batchify(self, indices, batch_size, seed):
"""
indices 分为 batches
:param sorted_indices: List[int]
:param batch_size: int
:param seed: int
:return: List[List[int]]
"""
# 实际的 bucket 大小
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
num_samples = 0
batches = []
while num_samples<len(indices):
batches.append(indices[num_samples:num_samples+batch_size])
num_samples += batch_size
return batches
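For intuition, a quick trace of batchify; the permutation shown is illustrative, since the exact order depends on numpy's default_rng stream:

    # indices = [0, 1, 2, 3, 4], batch_size = 2
    # rng.shuffle may give e.g. [3, 0, 4, 1, 2]
    # sequential chunking then yields [[3, 0], [4, 1], [2]]
    # (the tail batch may be smaller than batch_size; drop_last decides its fate later)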
def set_epoch(self, epoch):
self.epoch = epoch
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size
@property
def total_size(self):
"""
The total number of indices this sampler will eventually produce, including those for other ranks. Because of
replicas and padding, this value may be equal to, greater than, or smaller than len(dataset).
:return:
"""
return self.num_consumed_samples + self.num_replicas*self.num_left_samples
@property
def num_left_samples(self):
"""
The number of samples left in the current iteration; this counts what remains on the current rank.
:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
def __len__(self)->int:
"""
The number of batches this sampler will still yield.
:return:
"""
num_sampler_per_rank = self.total_size//self.num_replicas
num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
(num_sampler_per_rank+self.batch_size-1)//self.batch_size
return num_batches
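A worked check of the bookkeeping above: with len(dataset) = 10, num_replicas = 4, pad = True and nothing consumed yet, num_left_samples = ceil(10 / 4) = 3 per rank and total_size = 0 + 4 * 3 = 12 (two of the twelve indices are padding). With batch_size = 2, num_sampler_per_rank = 12 // 4 = 3, so __len__ is 3 // 2 = 1 if drop_last else (3 + 2 - 1) // 2 = 2.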
def state_dict(self) -> Dict:
if self.old_batch_size != self.batch_size:
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_replicas': self.num_replicas}
return states
def load_state_dict(self, states: Dict):
# if self.during_iter is True, then num_consumed_samples must be 0
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length:  # if the last sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]
self.old_batch_size = states['batch_size']
self.old_num_replicas = states['num_replicas']
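A resume sketch under these state_dict()/load_state_dict() semantics; range(100) stands in for anything implementing __len__:

    sampler = RandomBatchSampler(dataset=range(100), batch_size=8, shuffle=True, seed=42)
    for i, batch in enumerate(sampler):
        if i == 2:
            state = sampler.state_dict()  # records num_consumed_samples so far
            break

    restored = RandomBatchSampler(dataset=range(100), batch_size=8, shuffle=True, seed=42)
    restored.load_state_dict(state)
    # iterating `restored` now skips the samples consumed before the save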
class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):

View File

@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet
class ReproducibleSampler:
"""
A reproducible Sampler object.
Note: the `__init__` of every class inheriting `ReproducibleSampler` must accept `**kwargs`, so that the sampler
(or batch_sampler) can be re-instantiated when resuming from a checkpoint. Also note: no variable initialized in
`__init__` may start with an underscore `_`, and every variable not set in `__init__` must start with one.
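A minimal subclass sketch that follows the convention just described (illustrative only, not part of the library):

    class MySampler(ReproducibleSampler):
        def __init__(self, dataset, seed: int = 0, **kwargs):
            super().__init__()
            self.dataset = dataset  # set in __init__: no leading underscore
            self.seed = seed
            self.num_consumed_samples = kwargs.get('num_consumed_samples', 0)
            self._num_yielded = 0   # internal state not restored through __init__: leading underscore required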
@ -54,13 +56,12 @@ class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""
:param dataset: a data container implementing __len__
:param shuffle: whether to shuffle the order on every iteration
:param seed: the random seed
:param kwargs: not needed by users; used internally by fastNLP
"""
super(RandomSampler, self).__init__()
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed

View File

@ -21,7 +21,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
from ..dataloaders.utils import indice_collate_wrapper

View File

@ -1,5 +1,5 @@
import functools
class DummyClass:
def __call__(self, *args, **kwargs):
return
def __init__(self, *args, **kwargs):
pass

View File

@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]):
else:
return data.cuda(get_paddle_device_id(device))
def get_paddle_gpu_str(device: Union[str, int]):
"""
Get a device name of the form `gpu:x`.
@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]):
return device.replace("cuda", "gpu")
return f"gpu:{device}"
def get_paddle_device_id(device: Union[str, int]):
"""
Get the device id of the gpu.
@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)
def is_in_paddle_dist():
"""
Determine whether this is a distributed process, judged via global_rank and selected_gpus.
"""
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)
def is_in_fnlp_paddle_dist():
"""
Determine whether this is a distributed process launched by FastNLP.
"""
return FASTNLP_DISTRIBUTED_CHECK in os.environ
def is_in_paddle_launch_dist():
"""
Determine whether this is a distributed process started via launch.

View File

@ -6,7 +6,7 @@ import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional
from time import sleep
@ -35,7 +35,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@ -513,24 +512,6 @@ class Option(dict):
self.update(state)
def indice_collate_wrapper(func):
"""
Wraps a collate_fn one level: the (idx, instance) tuples fetched from the dataset are split apart, the idx values are packed into indices, and the instances are passed to the wrapped function.
:param func: the function to wrap
:return:
"""
def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
return wrapper
_emitted_deprecation_warnings = set()

View File

@ -332,13 +332,44 @@ class DataBundle:
show_progress_bar=show_progress_bar, progress_desc=progress_desc)
return res
def set_pad_val(self, *field_names, val=0) -> None:
for _, ds in self.iter_datasets():
ds.set_pad_val(*field_names, val=val)
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":
"""
If the content of some field needs special adjustment, use this function.
def set_input(self, *field_names) -> None:
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP
only pads fields that are paddable, so if the field is not paddable in the first place there is no need to
set this to None. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var outputs respectively. When pad_val is None, this value is meaningless.
:param pad_fn: a custom pad function for this field. If given, pad_val, dtype, backend and the other parameters
are ignored. The input of pad_fn is the batch form of this field: the Collator automatically unbatches the data
and groups each field into its own batch, which is what pad_fn receives; its output is used directly as the result.
:return: self
"""
for _, ds in self.iter_datasets():
ds.set_input(*field_names)
ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend,
pad_fn=pad_fn)
return self
def set_ignore(self, *field_names) -> "DataBundle":
"""
If some content should not be output, configure it here; the specified fields will be ignored in the batch output.
Ex::
collator.set_ignore('field1', 'field2')
:param field_names: the names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's
key can be used directly; for a nested dict, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}; if
__getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element.
:return: self
"""
for _, ds in self.iter_datasets():
ds.collator.set_ignore(*field_names)
return self
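Since set_pad and set_ignore now fan out to every dataset's collator, one call configures the whole bundle. A sketch, assuming a data_bundle holding 'train' and 'dev' splits with the (hypothetical) field names below:

    data_bundle.set_pad('words', pad_val=0, backend='torch')  # applied to each dataset's collator
    data_bundle.set_ignore('raw_words')                       # dropped from every split's batches
    # both return the DataBundle, so the calls can be chained:
    data_bundle.set_pad('target', pad_val=None).set_ignore('meta')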
def __repr__(self) -> str:
_str = ''

View File

@ -0,0 +1,208 @@
import pytest
from functools import reduce
from fastNLP.core.callbacks.callback_event import Event, Filter
class TestFilter:
def test_every_filter(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w-1 for w in range(10, 101, 10)]
# every = 1
@Filter(every=1)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))
def test_once_filter(self):
# once = 10
@Filter(once=10)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]
def test_extract_filter_from_fn(self):
@Filter(every=10)
def _fn(data):
return data
_filter_num_called = []
_filter_num_executed = []
for i in range(100):
cu_res = _fn(i)
_filter = _fn.__fastNLP_filter__
_filter_num_called.append(_filter.num_called)
_filter_num_executed.append(_filter.num_executed)
assert _filter_num_called == list(range(1, 101))
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]
def _fn(data):
return data
assert not hasattr(_fn, "__fastNLP_filter__")
def test_filter_state_dict(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data
_res = []
for i in range(50):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 51, 10)]
# save the state
state = _fn.__fastNLP_filter__.state_dict()
# load the state
_fn.__fastNLP_filter__.load_state_dict(state)
_res = []
for i in range(50, 100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(60, 101, 10)]
@pytest.mark.torch
def test_filter_fn_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False
@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data
_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]
class TestCallbackEvents:
def test_every(self):
# which event it is does not matter here, since we test it separately from the Trainer;
event_state = Event.on_train_begin()  # with no arguments, every=1 should be the default
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))
event_state = Event.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 101, 10)]
def test_once(self):
event_state = Event.on_train_begin(once=10)
@Filter(once=event_state.once)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]
@pytest.mark.torch
def test_callback_events_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False
event_state = Event.on_train_begin(filter_fn=filter_fn)
@Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data):
return data
_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]

View File

@ -1,157 +0,0 @@
import pytest
from functools import reduce
from fastNLP.core.callbacks.callback_events import Events, Filter
class TestFilter:
def test_params_check(self):
# passes normally
_filter1 = Filter(every=10)
_filter2 = Filter(once=10)
_filter3 = Filter(filter_fn=lambda: None)
# triggers ValueError
with pytest.raises(ValueError) as e:
_filter4 = Filter()
exec_msg = e.value.args[0]
assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter."
# triggers ValueError
with pytest.raises(ValueError) as e:
_filter5 = Filter(every=10, once=10)
exec_msg = e.value.args[0]
assert exec_msg == "These three values should be only set one."
# triggers ValueError
with pytest.raises(ValueError) as e:
_filter6 = Filter(every="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument every should be integer and greater than zero"
# triggers ValueError
with pytest.raises(ValueError) as e:
_filter7 = Filter(once="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument once should be integer and positive"
# triggers TypeError
with pytest.raises(TypeError) as e:
_filter7 = Filter(filter_fn="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument event_filter should be a callable"
def test_every_filter(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w-1 for w in range(10, 101, 10)]
# every = 1
@Filter(every=1)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))
def test_once_filter(self):
# once = 10
@Filter(once=10)
def _fn(data):
return data
_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]
def test_filter_fn(self):
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False
@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data
_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]
def test_extract_filter_from_fn(self):
@Filter(every=10)
def _fn(data):
return data
_filter_num_called = []
_filter_num_executed = []
for i in range(100):
cu_res = _fn(i)
_filter = _fn.__fastNLP_filter__
_filter_num_called.append(_filter.num_called)
_filter_num_executed.append(_filter.num_executed)
assert _filter_num_called == list(range(1, 101))
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]
def _fn(data):
return data
assert not hasattr(_fn, "__fastNLP_filter__")
def test_filter_state_dict(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data
_res = []
for i in range(50):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 51, 10)]
# save the state
state = _fn.__fastNLP_filter__.state_dict()
# load the state
_fn.__fastNLP_filter__.load_state_dict(state)
_res = []
for i in range(50, 100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(60, 101, 10)]

View File

@ -2,9 +2,6 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re
import time
@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.log import logger
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
@dataclass
class ArgMaxDatasetConfig:
@ -216,9 +218,9 @@ def test_model_checkpoint_callback_2(
path = Path.cwd().joinpath("test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event
@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError
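The hunk above reflects the move from the Events enum to the Event factory: the event is now called, and the call can carry filter arguments. A hedged sketch of the new registration style (every=2 is an illustrative choice):

    from fastNLP.core.callbacks.callback_event import Event

    @Trainer.on(Event.on_train_epoch_end(every=2))  # run at the end of every second epoch
    def log_epoch(trainer):
        print(trainer.cur_epoch_idx)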
@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2(
if version == 0:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2(
topk=None,
last=False,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]
elif version == 1:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2(
topk=1,
last=True,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]

View File

@ -12,9 +12,7 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re
@ -29,7 +27,11 @@ from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
@dataclass
class ArgMaxDatasetConfig:

View File

@ -17,12 +17,13 @@ def test_get_element_shape_dtype():
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
@pytest.mark.torch
@pytest.mark.paddle
@pytest.mark.jittor
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")
if not _NEED_IMPORT_PADDLE and backend == 'paddle':
pytest.skip("No paddle")
if not _NEED_IMPORT_PADDLE and backend == 'jittor':
if not _NEED_IMPORT_JITTOR and backend == 'jittor':
pytest.skip("No jittor")
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
@ -66,6 +67,13 @@ def test_raw_padder():
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3, 2)
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, list)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12
def test_numpy_padder():
backend = 'numpy'
@ -140,3 +148,18 @@ def test_torch_padder():
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
# numpy.ndarray is also accepted
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12
# test conversion to numpy
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, np.ndarray)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12

View File

@ -1,7 +1,7 @@
import numpy as np
import pytest
from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE:
@pytest.mark.paddle
class TestpaddleNumberPadder:
class TestPaddleNumberPadder:
def test_run(self):
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, paddle.Tensor)
@ -20,9 +20,9 @@ class TestpaddleNumberPadder:
@pytest.mark.paddle
class TestpaddleSequencePadder:
class TestPaddleSequencePadder:
def test_run(self):
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
@ -32,20 +32,20 @@ class TestpaddleSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]
def test_dtype_check(self):
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
# assert (a>67).sum()==0 # because 322 overflows int8 (range -128..127)
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
@pytest.mark.paddle
class TestpaddleTensorPadder:
class TestPaddleTensorPadder:
def test_run(self):
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
a = [paddle.zeros((3,)), paddle.zeros((2,))]
a = padder(a)
shape = a.shape
@ -74,7 +74,7 @@ class TestpaddleTensorPadder:
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))]
a = padder(a)
shape = a.shape
@ -85,7 +85,7 @@ class TestpaddleTensorPadder:
])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)]
a = padder(a)
shape = a.shape
@ -96,11 +96,11 @@ class TestpaddleTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]
def test_dtype_check(self):
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)
def test_v1(self):
print(paddle.zeros((3, )).dtype)

View File

@ -23,7 +23,6 @@ class TestRawSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]
def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)

View File

@ -1,81 +1,293 @@
import numpy as np
import pytest
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR
from fastNLP.core.collators.collator import Collator
def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()
def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))
def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)
class TestCollator:
@pytest.mark.parametrize('as_numpy', [True, False])
def test_auto_collator(self, as_numpy):
"""
Tests the auto_pad functionality of auto_collator
@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}
findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))
if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}
findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))
def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
# test ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# test set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
# test setting the pad value
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))
def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# test ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))
collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)
collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)
:param as_numpy:
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=as_numpy)
collator.set_input('x', 'y')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
results = []
for bucket in bucket_data:
res = collator(bucket)
assert res['x'].shape == (40, 5)
assert res['y'].shape == (40,)
results.append(res)
def test_auto_collator_v1(self):
"""
Tests the set_pad_val and set_as_numpy functionality of auto_collator
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
collator.set_input('x')
collator.set_pad_val('x', val=-1)
collator.set_as_numpy(True)
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = collator(bucket)
print(res)
def test_multicollator(self):
"""
Tests the _MultiCollator functionality
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
multi_collator = _MultiCollator(collator)
multi_collator.set_as_numpy(as_numpy=True)
multi_collator.set_pad_val('x', val=-1)
multi_collator.set_input('x')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = multi_collator(bucket)
print(res)

View File

@ -1,293 +0,0 @@
import numpy as np
import pytest
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR
from fastNLP.core.collators.new_collator import Collator
def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()
def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))
def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)
class TestCollator:
@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}
findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))
if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}
findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))
def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# test set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
# test setting the pad value
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))
# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))
def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))
collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)
collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)
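Taken together, the deleted tests above document the Collator surface that survives this commit. A condensed usage sketch, using only calls asserted above:
collator = Collator(backend='raw')
collator.set_ignore('str', ('nested_dict', 'int'))     # drop top-level or nested fields
collator.set_pad('nest_lst_int', pad_val=100)          # per-field pad value
collator.set_pad('float', backend='numpy', dtype=int)  # per-field backend and dtype
batch = collator(dict_batch)                           # dict_batch as defined above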

View File

@ -1,17 +1,20 @@
import pytest
from typing import Any
from dataclasses import dataclass
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist
@dataclass
@ -62,12 +65,11 @@ def model_and_optimizers():
return trainer_params
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger(
def test_trainer_event_trigger_1(
model_and_optimizers: TrainerParameters,
driver,
device,
@ -97,8 +99,215 @@ def test_trainer_event_trigger(
if dist.is_initialized():
dist.destroy_process_group()
for name, member in Events.__members__.items():
assert member.value in output[0]
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
@Trainer.on(Event.on_after_trainer_initialized())
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")
@Trainer.on(Event.on_sanity_check_begin())
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")
@Trainer.on(Event.on_sanity_check_end())
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")
@Trainer.on(Event.on_train_begin())
def on_train_begin(trainer):
print("on_train_begin")
@Trainer.on(Event.on_train_end())
def on_train_end(trainer):
print("on_train_end")
@Trainer.on(Event.on_train_epoch_begin())
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# trigger on_exception
raise Exception
print("on_train_epoch_begin")
@Trainer.on(Event.on_train_epoch_end())
def on_train_epoch_end(trainer):
print("on_train_epoch_end")
@Trainer.on(Event.on_fetch_data_begin())
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")
@Trainer.on(Event.on_fetch_data_end())
def on_fetch_data_end(trainer):
print("on_fetch_data_end")
@Trainer.on(Event.on_train_batch_begin())
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")
@Trainer.on(Event.on_train_batch_end())
def on_train_batch_end(trainer):
print("on_train_batch_end")
@Trainer.on(Event.on_exception())
def on_exception(trainer, exception):
print("on_exception")
@Trainer.on(Event.on_before_backward())
def on_before_backward(trainer, outputs):
print("on_before_backward")
@Trainer.on(Event.on_after_backward())
def on_after_backward(trainer):
print("on_after_backward")
@Trainer.on(Event.on_before_optimizers_step())
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")
@Trainer.on(Event.on_after_optimizers_step())
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")
@Trainer.on(Event.on_before_zero_grad())
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")
@Trainer.on(Event.on_after_zero_grad())
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")
@Trainer.on(Event.on_evaluate_begin())
def on_evaluate_begin(trainer):
print("on_evaluate_begin")
@Trainer.on(Event.on_evaluate_end())
def on_evaluate_end(trainer, results):
print("on_evaluate_end")
with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=n_epochs,
)
trainer.run()
if dist.is_initialized():
dist.destroy_process_group()
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
import re
once_message_1 = "This message should be typed 1 times."
once_message_2 = "test_filter_fn"
once_message_3 = "once message 3"
twice_message = "twice message hei hei"
@Trainer.on(Event.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)
@Trainer.on(Event.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)
@Trainer.on(Event.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)
def filter_fn(filter, trainer):
if trainer.cur_epoch_idx == 1:
return True
else:
return False
@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=n_epochs,
)
trainer.run()
if dist.is_initialized():
dist.destroy_process_group()
once_pattern_1 = re.compile(once_message_1)
once_pattern_2 = re.compile(once_message_2)
once_pattern_3 = re.compile(once_message_3)
twice_pattern = re.compile(twice_message)
once_res_1 = once_pattern_1.findall(output[0])
assert len(once_res_1) == 1
once_res_2 = once_pattern_2.findall(output[0])
assert len(once_res_2) == 1
once_res_3 = once_pattern_3.findall(output[0])
assert len(once_res_3) == 1
twice_res = twice_pattern.findall(output[0])
assert len(twice_res) == 2
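For reference, the every=2/once counting asserted here can be modeled by a small call-counting filter. The class below is a hypothetical stand-in, not fastNLP's actual Filter:
class EveryFilter:
    # Fire the wrapped callback only on every `every`-th trigger.
    def __init__(self, every):
        self.every = every
        self.num_called = 0
    def __call__(self, fn, *args, **kwargs):
        self.num_called += 1
        if self.num_called % self.every == 0:
            return fn(*args, **kwargs)
# With every=2 and n_epochs=2 the callback fires exactly once, matching once_res_1.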

View File

@ -1,22 +1,22 @@
import pytest
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from fastNLP.core.callbacks import Event
from tests.helpers.utils import magic_argv_env_context
@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.on_train_epoch_begin(every=10))
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
def fn1(trainer):
pass
@Trainer.on(Events.on_train_batch_begin(every=10))
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn2(trainer, batch, indices):
pass
with pytest.raises(AssertionError):
@Trainer.on(Events.on_train_batch_begin(every=10))
with pytest.raises(BaseException):
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn3(trainer, batch):
pass

View File

@ -2,9 +2,7 @@
Note that the test functions in this file should all be derived from functions already tested in `test_trainer_w_evaluator_torch.py`, modified by adding metrics and an evaluator.
"""
import pytest
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist
from dataclasses import dataclass
from typing import Any
from torchmetrics import Accuracy
@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist
@dataclass
class NormalClassificationTrainTorchConfig:

View File

@ -2,9 +2,7 @@ import os.path
import subprocess
import sys
import pytest
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Any
from pathlib import Path
@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader
@dataclass
@ -257,9 +260,9 @@ def test_trainer_on_exception(
cur_rank,
n_epochs=2,
):
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event
@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == cur_rank:
raise NotImplementedError
@ -286,6 +289,7 @@ def test_trainer_on_exception(
dist.destroy_process_group()
@pytest.mark.torch
@pytest.mark.parametrize("version", [0, 1, 2, 3])
@magic_argv_env_context
def test_torch_distributed_launch_1(version):

View File

@ -1,7 +1,7 @@
from functools import reduce
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: this class has been modified; remember to update the test as well
from tests.helpers.datasets.normal_data import NormalIterator
from tests.helpers.datasets.normal_data import NormalSampler
class Test_WrapDataLoader:
@ -9,9 +9,9 @@ class Test_WrapDataLoader:
def test_normal_generator(self):
all_sanity_batches = [4, 20, 100]
for sanity_batches in all_sanity_batches:
data = NormalIterator(num_of_data=1000)
data = NormalSampler(num_of_data=1000)
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
dataloader = iter(wrapper(dataloader=data))
dataloader = iter(wrapper)
mark = 0
while True:
try:
@ -32,8 +32,7 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
dataloader = iter(dataloader)
dataloader = iter(wrapper)
all_supposed_running_data_num = 0
while True:
try:
@ -55,6 +54,5 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
length.append(len(dataloader))
length.append(len(wrapper))
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))])
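The change under test: _TruncatedDataLoader is now iterated directly instead of being called with the dataloader again. A minimal stand-in showing the contract (not the actual implementation):
class TruncatedLoader:
    def __init__(self, dataloader, num_batches):
        self.dataloader = dataloader
        self.num_batches = num_batches
    def __iter__(self):
        # Stop after num_batches batches, regardless of the underlying length.
        for idx, batch in enumerate(self.dataloader):
            if idx >= self.num_batches:
                break
            yield batch
    def __len__(self):
        return min(self.num_batches, len(self.dataloader))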

View File

@ -15,7 +15,7 @@ else:
class Model (Module):
class Model(Module):
def __init__ (self):
super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
@ -45,6 +45,7 @@ class Model (Module):
return x
@pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice:
def test_on_gpu_without_fp16(self):

View File

@ -2,7 +2,7 @@ import pytest
from pathlib import Path
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
res = PaddleSingleDriver.get_dataloader_args(dataloader)
assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
"""
Test the behavior of set_dist_repro_dataloader when the parameter `reproducible` is True.
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler (shuffle=True),
only the Sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a RandomBatchSampler
only the Sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
应该返回新的 dataloader并将 batch_sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
# reload: the remaining content should come out, and for PaddleNormalDataset the indices, once sorted, should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# change batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
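Condensed, the resume flow these assertions verify looks like the sketch below (names as in the test; it assumes, as the test does, that ReproduceBatchSampler restores index_list and num_consumed_samples from its state dict):
sampler_states = dataloader.batch_sampler.state_dict()
sampler_states["num_consumed_samples"] = num_consumed_batches * 4  # skip consumed samples
new_sampler = ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
new_sampler.load_state_dict(sampler_states)
new_loader = DataLoader(dataset=dataset, batch_sampler=new_sampler)  # yields only the remainder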

View File

@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
def test_replace_batch_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16
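replace_batch_sampler has to rebuild the DataLoader around the new batch_sampler, because batch_size/shuffle/drop_last are owned by the batch_sampler and may not be passed again. A naive sketch (attribute names assumed; the real helper copies more settings):
def naive_replace_batch_sampler(dataloader, new_batch_sampler):
    # Keep the same dataset and collation; swap in the new batch_sampler.
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
    )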

View File

@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
#
############################################################################
@pytest.mark.torch
@magic_argv_env_context
def test_multi_drivers():
"""
Test the case where multiple TorchDDPDriver instances are used.
"""
generate_driver(10, 10)
generate_driver(20, 10)
with pytest.raises(RuntimeError):
# different device settings should raise an error
generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@pytest.mark.torch
class TestDDPDriverFunction:
"""
Test class for some simple TorchDDPDriver functions; mostly checks that they run and that there are no import errors.
"""
@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)
@magic_argv_env_context
def test_multi_drivers(self):
def test_simple_functions(self):
"""
Test the case where multiple TorchDDPDriver instances are used.
A quick test of several functions.
"""
driver2 = generate_driver(20, 10)
with pytest.raises(RuntimeError):
# different device settings should raise an error
driver3 = generate_driver(20, 3, device=[0,1,2])
assert False
driver = generate_driver(10, 10)
"""
Test the move_data_to_device function. It only calls torch_move_data_to_device; the test cases live in
tests/core/utils/test_torch_utils.py, so they are not repeated here.
"""
driver.move_data_to_device(torch.rand((32, 64)))
dist.barrier()
@magic_argv_env_context
def test_move_data_to_device(self):
"""
This function only calls torch_move_data_to_device; its test cases are in tests/core/utils/test_torch_utils.py,
so they are not repeated here.
"""
self.driver.move_data_to_device(torch.rand((32, 64)))
dist.barrier()
@magic_argv_env_context
def test_is_distributed(self):
"""
Test the is_distributed function.
"""
assert self.driver.is_distributed() == True
assert driver.is_distributed() == True
dist.barrier()
@magic_argv_env_context
def test_get_no_sync_context(self):
"""
Test the get_no_sync_context function.
"""
res = self.driver.get_model_no_sync_context()
res = driver.get_model_no_sync_context()
dist.barrier()
@magic_argv_env_context
def test_is_global_zero(self):
"""
Test the is_global_zero function.
"""
self.driver.is_global_zero()
driver.is_global_zero()
dist.barrier()
@magic_argv_env_context
def test_unwrap_model(self):
"""
Test the unwrap_model function.
"""
self.driver.unwrap_model()
driver.unwrap_model()
dist.barrier()
@magic_argv_env_context
def test_get_local_rank(self):
"""
Test the get_local_rank function.
"""
self.driver.get_local_rank()
driver.get_local_rank()
dist.barrier()
@magic_argv_env_context
def test_all_gather(self):
"""
Test the all_gather function.
Detailed tests are in test_dist_utils.py.
"""
obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
}
obj_list = self.driver.all_gather(obj, group=None)
obj_list = driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list):
assert res["rank"] == i
@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
"""
Test the broadcast_object function.
Detailed tests are in test_dist_utils.py.
"""
if self.driver.global_rank == src_rank:
if driver.global_rank == 0:
obj = {
"rank": self.driver.global_rank
"rank": driver.global_rank
}
else:
obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank
res = driver.broadcast_object(obj, src=0)
assert res["rank"] == 0
if dist.is_initialized():
dist.destroy_process_group()
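The all_gather and broadcast checks above reduce to torch.distributed's object collectives; roughly (a sketch of the underlying calls, not the driver's exact code):
import torch.distributed as dist
obj = {"rank": dist.get_rank()}
# all_gather: every rank contributes one picklable object and receives the full list.
obj_list = [None] * dist.get_world_size()
dist.all_gather_object(obj_list, obj)
# broadcast: rank 0's object is received by every rank.
received = [obj if dist.get_rank() == 0 else None]
dist.broadcast_object_list(received, src=0)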
############################################################################
#
@ -187,7 +180,6 @@ class TestSetDistReproDataloader:
@classmethod
def setup_class(cls):
cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)
def setup_method(self):
self.dataset = TorchNormalDataset(40)
@ -204,17 +196,20 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader dist BucketedBatchSampler 时的表现
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -223,9 +218,10 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader dist RandomSampler 时的表现
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -234,9 +230,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
"""
The case where the parameter `dist` is None. This occurs during initialization of the trainer and evaluator, when the user has specified `use_dist_sampler`
@ -251,15 +249,17 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader dist Nonereproducible True 时的表现
当用户在 driver 之外初始化了分布式环境时fastnlp 不支持进行断点重训此时应该报错
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
# a RuntimeError should be raised
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
# @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
@ -268,21 +268,24 @@ class TestSetDistReproDataloader:
此时传入的 dataloader batch_sampler 应该已经执行了 set_distributed产生一个新的 dataloader batch_sampler
和原 dataloader 相同
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -292,12 +295,13 @@ class TestSetDistReproDataloader:
此时传入的 dataloader batch_sampler.sampler 应该已经执行了 set_distributed产生一个新的 dataloader
batch_sampler.sampler 和原 dataloader 相同
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -307,9 +311,11 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -318,11 +324,14 @@ class TestSetDistReproDataloader:
测试 set_dist_repro_dataloader dist Nonereproducible False dataloader 为一般情况时的表现
此时直接返回原来的 dataloader不做任何处理
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
assert replaced_loader is dataloader
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
"""
The case where the parameter `dist` is 'dist'. This occurs during the trainer's initialization, when the user has specified the `use_dist_sampler` parameter
@ -337,12 +346,13 @@ class TestSetDistReproDataloader:
the expected behavior.
A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@ -351,6 +361,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -361,8 +373,9 @@ class TestSetDistReproDataloader:
A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@ -372,6 +385,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -381,8 +396,9 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with its batch_sampler.sampler replaced by a RandomSampler and the distributed
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -392,6 +408,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
"""
The case where the parameter `dist` is 'unrepeatdist'. This occurs during the evaluator's initialization, when the user has specified the `use_dist_sampler` parameter
@ -407,8 +425,9 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with the original Sampler replaced by an UnrepeatedRandomSampler and the distributed
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -418,6 +437,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -427,8 +448,9 @@ class TestSetDistReproDataloader:
the expected behavior.
A new dataloader should be returned, with the original Sampler re-instantiated
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -439,6 +461,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@ -448,8 +472,9 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with the sampler replaced by an UnrepeatedSequentialSampler and the distributed
attributes set correctly
"""
driver = generate_driver(10, 10, device=self.device)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@ -459,6 +484,8 @@ class TestSetDistReproDataloader:
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()
def check_distributed_sampler(self, sampler):
"""
@ -469,7 +496,7 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
"""
Test that set_dist_repro_dataloader produces correct results in the multi-GPU case.
"""
@ -501,8 +528,8 @@ class TestSetDistReproDataloader:
drop_last=False,
)
new_loader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
num_replicas=driver.world_size,
rank=driver.global_rank,
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
@ -512,8 +539,8 @@ class TestSetDistReproDataloader:
# rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
new_loader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
num_replicas=driver.world_size,
rank=driver.global_rank
)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
@ -534,11 +561,6 @@ class TestSaveLoad:
Test the behavior of the save- and load-related functions in the multi-GPU case.
"""
@classmethod
def setup_class(cls):
# an error is raised if setup is not done here
cls.driver = generate_driver(10, 10)
def setup_method(self):
self.dataset = TorchArgMaxDataset(10, 20)
@ -552,26 +574,26 @@ class TestSaveLoad:
path = "model"
dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)
self.driver1.save_model(path, only_state_dict)
driver1.save_model(path, only_state_dict)
# synchronize
dist.barrier()
self.driver2.load_model(path, only_state_dict)
driver2.load_model(path, only_state_dict)
for idx, batch in enumerate(dataloader):
batch = self.driver1.move_data_to_device(batch)
res1 = self.driver1.model(
batch = driver1.move_data_to_device(batch)
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@ -580,6 +602,9 @@ class TestSaveLoad:
finally:
rank_zero_rm(path)
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@ -593,7 +618,7 @@ class TestSaveLoad:
path = "model.ckp"
num_replicas = len(device)
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
dataloader = dataloader_with_bucketedbatchsampler(
self.dataset,
@ -603,8 +628,8 @@ class TestSaveLoad:
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True
)
num_consumed_batches = 2
@ -623,7 +648,7 @@ class TestSaveLoad:
# save the state
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
# load
# change batch_size
dataloader = dataloader_with_bucketedbatchsampler(
@ -634,11 +659,11 @@ class TestSaveLoad:
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. check the optimizer state
# TODO the optimizer's state_dict is always empty
@ -652,7 +677,7 @@ class TestSaveLoad:
# 3. check whether fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
@ -664,16 +689,16 @@ class TestSaveLoad:
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@ -686,6 +711,9 @@ class TestSaveLoad:
finally:
rank_zero_rm(path)
if dist.is_initialized():
dist.destroy_process_group()
@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@ -700,13 +728,13 @@ class TestSaveLoad:
num_replicas = len(device)
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
driver1 = generate_driver(10, 10, device=device, fp16=fp16)
driver2 = generate_driver(10, 10, device=device, fp16=False)
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
num_replicas=driver1.world_size,
rank=driver1.global_rank,
pad=True
)
num_consumed_batches = 2
@ -726,18 +754,18 @@ class TestSaveLoad:
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
# load
# change batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
num_replicas=driver2.world_size,
rank=driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. check the optimizer state
@ -753,7 +781,7 @@ class TestSaveLoad:
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. check whether fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
@ -765,16 +793,16 @@ class TestSaveLoad:
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
res1 = driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
fastnlp_fn=driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
res2 = driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_fn=driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
@ -786,4 +814,7 @@ class TestSaveLoad:
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
rank_zero_rm(path)
rank_zero_rm(path)
if dist.is_initialized():
dist.destroy_process_group()

View File

@ -2,12 +2,14 @@ import pytest
from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
from fastNLP.envs import get_gpu_count
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.utils import magic_argv_env_context
import torch
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import device as torchdevice
else:
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice
@pytest.mark.torch
def test_incorrect_driver():
@ -20,7 +22,7 @@ def test_incorrect_driver():
@pytest.mark.torch
@pytest.mark.parametrize(
"device",
["cpu", "cuda:0", 0, torch.device("cuda:0")]
["cpu", "cuda:0", 0, torchdevice("cuda:0")]
)
@pytest.mark.parametrize(
"driver",
@ -83,7 +85,6 @@ def test_get_ddp(driver, device):
("driver", "device"),
[("torch_ddp", "cpu")]
)
@magic_argv_env_context
def test_get_ddp_cpu(driver, device):
"""
Test attempting to initialize distributed training on CPU.
@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device):
@pytest.mark.torch
@pytest.mark.parametrize(
"device",
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
[-2, [0, 20, 3], [-2], 20]
)
@pytest.mark.parametrize(
"driver",
["torch", "torch_ddp"]
)
@magic_argv_env_context
def test_device_out_of_range(driver, device):
"""
Test the case where the passed-in device is out of range.

View File

@ -2,7 +2,7 @@ import pytest
from pathlib import Path
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
"""
Build a dataloader whose batch_sampler is a RandomBatchSampler
Build a dataloader whose batch_sampler is a ReproduceBatchSampler
"""
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
sampler = torch.utils.data.SequentialSampler(dataset)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last
),
@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
res = TorchSingleDriver.get_dataloader_args(dataloader)
assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else:
@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
"""
Test the behavior of set_dist_repro_dataloader when the parameter `reproducible` is True.
When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler (shuffle=True),
only the Sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a RandomBatchSampler
only the Sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# the batch_sampler is replaced in this case
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with the batch_sampler replaced by the Sampler passed as dist
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
# reload: the remaining content should come out, and for TorchNormalDataset the indices, once sorted, should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
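The save/adjust/reload flow exercised above, end to end, in a self-contained sketch (names and state keys as asserted elsewhere in this diff):

from torch.utils.data import DataLoader, SequentialSampler, BatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler

def make_loader(dataset, batch_size):
    base = BatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=ReproduceBatchSampler(base, batch_size, drop_last=False))

dataset = list(range(100))
loader = make_loader(dataset, batch_size=4)
it = iter(loader)
for _ in range(3):
    next(it)  # consume three batches
states = loader.batch_sampler.state_dict()
# keys: "index_list", "num_consumed_samples" (== 12 here), "sampler_type"

loader = make_loader(dataset, batch_size=4)  # fresh dataloader, e.g. after a restart
loader.batch_sampler.load_state_dict(states)
next(iter(loader))  # resumes at sample 12: tensor([12, 13, 14, 15])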
@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is loaded and replaced correctly
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

View File

@ -30,7 +30,7 @@ class SequenceDataSet:
def check_replace_sampler(driver):
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / RandomBatchSampler
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproduceBatchSampler
# reproducible is either True or False
# we need to check that both the returned sampler and dataloader have changed

View File

@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from torch.utils.data import DataLoader, BatchSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
def test_replace_batch_sampler():
dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, TorchNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16

View File

@ -7,15 +7,20 @@ import copy
import socket
import pytest
import numpy as np
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
from sklearn.metrics import accuracy_score as sklearn_accuracy
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.metrics.metric import Metric
from .utils import find_free_network_port, setup_ddp, _assert_allclose
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method
set_start_method("spawn", force=True)
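This guarded-import pattern recurs throughout the change set: torch is imported only when the backend is available, module-level uses of torch symbols fall back to a dummy stand-in, and annotations become string literals so test collection never touches torch. Condensed, the idiom is:

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch
    from torch.multiprocessing import Pool, set_start_method
else:
    # stand-in callable so the module still imports without torch
    from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)

def _test(device: "torch.device") -> None:  # string annotation: evaluated lazily
    ...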
@ -26,7 +31,7 @@ pool = None
def _test(local_rank: int,
world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet,
metric_class: Type[Metric],
metric_kwargs: Dict[str, Any],

View File

@ -2,18 +2,23 @@ from functools import partial
import copy
import pytest
import torch
import numpy as np
from torch.multiprocessing import Pool, set_start_method
from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method
set_start_method("spawn", force=True)
def _test(local_rank: int, world_size: int, device: torch.device,
def _test(local_rank: int, world_size: int, device: "torch.device",
dataset: DataSet, metric_class, metric_kwargs, metric_result):
metric = metric_class(**metric_kwargs)
# the dataset is handled similarly (each process has its own copy)

View File

@ -5,16 +5,21 @@ import os, sys
import copy
from functools import partial
import torch
import torch.distributed
import numpy as np
import socket
from torch.multiprocessing import Pool, set_start_method
# from multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method
set_start_method("spawn", force=True)
@ -44,7 +49,7 @@ pool = None
def _test(local_rank: int,
world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet,
metric_class,
metric_kwargs,

View File

@ -2,9 +2,11 @@ import os, sys
import socket
from typing import Union
import torch
from torch import distributed
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
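The hunk ends at the signature of setup_ddp; its body is outside this diff. For orientation only, a typical body under the standard torch.distributed bootstrap might look like the following (a hypothetical sketch, not the repository's actual code):

import os
import torch.distributed

def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
    """Hypothetical body: join a gloo process group on localhost:master_port."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)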

View File

@ -1,161 +1,131 @@
from array import array
import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
from array import array
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
#
# class TestReproducibleBatchSampler:
# # TODO split this up; test only one thing here
# def test_torch_dataloader_1(self):
# import torch
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# next(iter_dataloader)
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "RandomBatchSampler"}
#
# # 2. resume from the checkpoint: build a new dataloader
# # without changing batch_size
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # change batch_size
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # check that the second epoch after resuming is a complete dataloader
# # first finish the epoch in which training was resumed;
# begin_idx = 27
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# # start a new epoch;
# begin_idx = 0
# iter_dataloader = iter(dataloader)
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# def test_torch_dataloader_2(self):
# # test that the index list of a new epoch is regenerated rather than carried over from the previous one;
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle to verify that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # save all data of one epoch to check that recovery is correct;
# all_supposed_data = []
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# all_supposed_data.extend(next(iter_dataloader).tolist())
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
# # 2. resume from the checkpoint: build a new dataloader
# # without changing batch_size
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # first consume the rest of this epoch;
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
# while True:
# try:
# all_supposed_data.extend(next(iter_dataloader).tolist())
# except StopIteration:
# break
# assert all_supposed_data == list(pre_index_list)
#
# # start a new epoch again;
# for _ in range(3):
# iter_dataloader = iter(dataloader)
# res = []
# while True:
# try:
# res.append(next(iter_dataloader))
# except StopIteration:
# break
#
# def test_3(self):
# import torch
# from torch.utils.data import DataLoader
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle to verify that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
#
# for idx, data in enumerate(dataloader):
# if idx > 3:
# break
#
# iterator = iter(dataloader)
# for each in iterator:
# pass
class TestReproducibleBatchSampler:
def test_1(self):
sampler = NormalSampler(num_of_data=100)  # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
forward_steps = 3
iterator = iter(reproduce_batch_sampler)
i = 0
while i < forward_steps:
next(iterator)
i += 1
# save the state;
state = reproduce_batch_sampler.state_dict()
assert state == {"index_list": array("I", list(range(100))),
"num_consumed_samples": forward_steps * 4,
"sampler_type": "ReproduceBatchSampler"}
# build a new batchsampler and load the state;
sampler = NormalSampler(num_of_data=100)  # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)
real_res = []
supposed_res = (list(range(12, 16)), list(range(16, 20)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))
for i in range(forward_steps):
assert supposed_res[i] == real_res[i]
# change the batch size
sampler = NormalSampler(num_of_data=100)  # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)
real_res = []
supposed_res = (list(range(12, 19)), list(range(19, 26)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))
for i in range(forward_steps):
assert supposed_res[i] == real_res[i]
# check that the second epoch after resuming is a complete dataloader
# first finish the epoch in which training was resumed;
begin_idx = 26
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break
# start a new epoch;
begin_idx = 0
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break
def test_2(self):
# test that the index list of a new epoch is regenerated rather than carried over from the previous one;
before_batch_size = 7
sampler = NormalSampler(num_of_data=100)
# enable shuffle to verify that the second epoch's index list is regenerated after resuming;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
# save all data of one epoch to check that recovery is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader))
# 1. save the state
state = reproduce_batch_sampler.state_dict()
# 2. resume from the checkpoint: build a new dataloader
# without changing batch_size
sampler = NormalSampler(num_of_data=100, shuffle=True)
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)
# first consume the rest of this epoch;
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
all_supposed_data.extend(next(iter_dataloader))
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)
# start a new epoch again;
for _ in range(3):
iter_dataloader = iter(reproduce_batch_sampler)
res = []
while True:
try:
res.extend(next(iter_dataloader))
except StopIteration:
break
assert res != all_supposed_data
class DatasetWithVaryLength:
@ -511,3 +481,313 @@ class TestBucketedBatchSampler:
already_seen_set.update(batch)
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
class TestRandomBatchSampler:
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
def test_single_num_batch(self, shuffle, drop_last, num):
# not having enough data should not raise an error; num is already parametrized above
dataset = DatasetWithVaryLength(num_of_data=num)
before_batch_size = 7
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
count = len(list(iter(re_batchsampler)))
if drop_last:
assert count==num//before_batch_size, num
else:
assert count==(num+before_batch_size-1)//before_batch_size, num
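The two assertions encode the standard batch-count arithmetic: floor division when incomplete batches are dropped, ceiling division otherwise. A quick worked example:

num, batch_size = 15, 7
assert num // batch_size == 2                     # drop_last=True: batches of 7 and 7; the last sample is dropped
assert (num + batch_size - 1) // batch_size == 3  # drop_last=False: batches of 7, 7 and 1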
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
def test_single(self, shuffle, drop_last):
before_batch_size = 7
num_batch_per_bucket = 4  # so the length difference within any batch should not exceed 4
dataset = DatasetWithVaryLength(num_of_data=1000)
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler.set_epoch(0)
forward_steps = 10
iterator = iter(re_batchsampler)
already_generate_indices = set()
for _ in range(forward_steps):
batch = next(iterator)
already_generate_indices.update(batch)
# 1. save the state
state = re_batchsampler.state_dict()
# 2. resume from the checkpoint and continue training
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler2.load_state_dict(state)
re_batchsampler2.set_epoch(0)
new_already_generate_indices = set()
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
for batch in re_batchsampler2:
for b in batch:
assert b not in already_generate_indices
new_already_generate_indices.update(batch)
if drop_last is False:
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)
# change batch_size
after_batch_size = 3
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler3.load_state_dict(state)
re_batchsampler3.set_epoch(0)
count = 0
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
for batch in re_batchsampler3:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break
# saving again is not allowed while the previous epoch has not finished sampling
after_batch_size = 5
with pytest.raises(RuntimeError):
state = re_batchsampler3.state_dict()
for batch in re_batchsampler3:  # consume all, so that saving becomes possible
pass
already_generate_indices = set()
count = 0
for batch in re_batchsampler3:  # start over
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break
state = re_batchsampler3.state_dict()
# drop_last is False here, so in the end every sample must be produced
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=False,
shuffle=shuffle)
re_batchsampler4.load_state_dict(state)
re_batchsampler4.set_epoch(0)
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
for batch in re_batchsampler4:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
assert len(already_generate_indices) == len(dataset)
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
def test_multi(self, shuffle, drop_last, pad):
# def test_multi(self, shuffle=True, drop_last=False, pad=False):
# no shuffle
num_replica = 2
dataset = DatasetWithVaryLength(num_of_data=1000)
batch_size = 5
num_batch_per_bucket = 10
lengths = []
rank0_already_seen_indexes = None
max_diff = num_batch_per_bucket * batch_size * num_replica
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
already_seen_indexes = set()
repeat_count = 0
for batch in sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if rank0_already_seen_indexes:  # indices must not overlap across ranks
assert b not in rank0_already_seen_indexes
already_seen_indexes.update(batch)
if rank0_already_seen_indexes is None:
rank0_already_seen_indexes = already_seen_indexes
if pad:  # at most one repeated sample should be allowed
assert repeat_count<=1
else:
assert repeat_count==0
assert len(set(lengths))==1, lengths  # every process gets the same number of batches
# saving in the multi-process case
already_seen_indexes = set()
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
count = 0
for batch in sampler:
already_seen_indexes.update(batch)
if count>5:
break
count += 1
state = sampler.state_dict()
# switch to a single machine
new_batch_size = 6
num_batch_per_bucket = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.load_state_dict(state)
repeat_count = 0
new_already_seen_indexes = set(list(already_seen_indexes))
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
for batch in new_sampler:
for b in batch:
repeat_count += int(b in new_already_seen_indexes)
new_already_seen_indexes.update(batch)
if pad:  # at most one repeated sample should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
if drop_last is False:  # without dropping, the counts should be equal
assert len(new_already_seen_indexes)==len(dataset)
# test changing the number of devices.
num_replica = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
repeat_count = 0
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
for batch in new_sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if pad:  # at most one repeated sample should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
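The distributed sharding exercised by these tests uses a small API surface; a sketch with the constructor and method names as used in this diff:

dataset = DatasetWithVaryLength(num_of_data=100)
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=6,
                             shuffle=True, drop_last=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=True)  # pad=True may repeat at most one sample
sampler.set_epoch(0)  # keep shuffling aligned across ranks
batches = list(sampler)  # every rank yields the same number of batches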
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Tests whether data already consumed by forward can be restored correctly.
:return:
"""
batch_size = 6
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
states['num_consumed_samples'] = num_consumed_samples_array[i]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
shuffle=shuffle, drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)
# test saving again after a save
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
states['num_consumed_samples'] = num_consumed_samples_array[2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break
states = sampler.state_dict()
num_consumed_samples_array = list(range(len(dataset)))
states['num_consumed_samples'] = num_consumed_samples_array[count]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)

View File

@ -0,0 +1,141 @@
from array import array
import torch
from torch.utils.data import DataLoader
import pytest
from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
def test_torch_dataloader_1(self):
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)
# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "ReproduceBatchSampler"}
# 2. resume from the checkpoint: build a new dataloader
# without changing batch_size
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))
for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])
# change batch_size
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))
for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])
# check that the second epoch after resuming is a complete dataloader
# first finish the epoch in which training was resumed;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break
# start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break
def test_torch_dataloader_2(self):
# test that the index list of a new epoch is regenerated rather than carried over from the previous one;
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to verify that the second epoch's index list is regenerated after resuming;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# save all data of one epoch to check that recovery is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())
# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()
# 2. resume from the checkpoint: build a new dataloader
# without changing batch_size
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
iter_dataloader = iter(dataloader)
# first consume the rest of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)
# start a new epoch again;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert res != all_supposed_data

View File

@ -3,6 +3,7 @@ import pytest
import subprocess
from io import StringIO
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
from fastNLP.core.utils.cache_results import cache_results
from fastNLP.core import rank_zero_rm

View File

@ -1,4 +1,5 @@
import os
import pytest
from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
@ -9,7 +10,7 @@ def test_dump_fastnlp_envs():
filepath = None
try:
with Capturing() as output:
dump_fastnlp_backend()
dump_fastnlp_backend(backend="torch")
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json')
assert filepath in output[0]
assert os.path.exists(filepath)

View File

@ -1,7 +1,9 @@
import torch
from copy import deepcopy
from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
class RecordAccumulationStepsCallback_Torch(Callback):

View File

@ -1,13 +1,25 @@
import numpy as np
import random
class NormalIterator:
def __init__(self, num_of_data=1000):
class NormalSampler:
def __init__(self, num_of_data=1000, shuffle=False):
self._num_of_data = num_of_data
self._data = list(range(num_of_data))
if shuffle:
random.shuffle(self._data)
self.shuffle = shuffle
self._index = 0
self.need_reinitialize = False
def __iter__(self):
if self.need_reinitialize:
self._index = 0
if self.shuffle:
random.shuffle(self._data)
else:
self.need_reinitialize = True
return self
def __next__(self):
@ -15,12 +27,45 @@ class NormalIterator:
raise StopIteration
_data = self._data[self._index]
self._index += 1
return self._data
return _data
def __len__(self):
return self._num_of_data
class NormalBatchSampler:
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
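A quick usage check for the two helpers above (with the default, unshuffled data):

sampler = NormalSampler(num_of_data=10)
batch_sampler = NormalBatchSampler(sampler, batch_size=4, drop_last=False)
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert len(batch_sampler) == 3  # (10 + 4 - 1) // 4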
class RandomDataset:
def __init__(self, num_data=10):
self.data = np.random.rand(num_data)
@ -29,4 +74,7 @@ class RandomDataset:
return len(self.data)
def __getitem__(self, item):
return self.data[item]
return self.data[item]

View File

@ -1,7 +1,11 @@
import torch
from functools import reduce
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
class TorchNormalDataset(Dataset):

View File

@ -1,9 +1,14 @@
import torch
import torch.nn as nn
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.nn import Module
import torch.nn as nn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
# 1. the most basic classification model
class TorchNormalModel_Classification_1(nn.Module):
class TorchNormalModel_Classification_1(Module):
"""
Implements train_step and evaluate_step separately.
"""
@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module):
return {"preds": x, "target": y}
class TorchNormalModel_Classification_2(nn.Module):
class TorchNormalModel_Classification_2(Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP externally.
"""
@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y}
class TorchNormalModel_Classification_3(nn.Module):
class TorchNormalModel_Classification_3(Module):
"""
Implements only a forward function, to test the scenario where the user initializes DDP externally.
With auto_param_call disabled, forward takes only a single batch argument.

tests/pytest.ini Normal file
View File

@ -0,0 +1,6 @@
[pytest]
markers =
torch
paddle
jittor
torchpaddle
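Registering these markers lets backend-specific tests be selected from the command line, for example:

pytest -m torch tests/          # run only torch-backend tests
pytest -m "not paddle" tests/   # skip paddle tests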