增加部分文档; Trainer

;Evaluator报错会尝试打印indices
This commit is contained in:
yh_cc 2022-05-10 21:09:49 +08:00
parent 7763b2e087
commit 8c5ac2776c
7 changed files with 44 additions and 27 deletions

View File

@ -34,7 +34,7 @@ class EvaluateBatchLoop(Loop):
except BaseException as e:
if callable(getattr(dataloader, 'get_batch_indices', None)):
indices = dataloader.get_batch_indices()
logger.debug(f"The following exception happens when running on samples: {indices}")
logger.error(f"Exception happens when evaluating on samples: {indices}")
raise e
self.batch_step_fn(evaluator, batch)

View File

@ -32,7 +32,7 @@ class TrainBatchLoop(Loop):
break
except BaseException as e:
if indices and not isinstance(e, EarlyStopException):
logger.debug(f"The following exception happens when running on samples: {indices}")
logger.error(f"Exception happens when running on samples: {indices}")
raise e
trainer.on_train_batch_begin(batch, indices)

View File

@ -514,7 +514,7 @@ class Trainer(TrainerEventTrigger):
else:
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")
if self.evaluator is not None and num_eval_sanity_batch > 0:
if self.evaluator is not None and num_eval_sanity_batch != 0:
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
self.on_sanity_check_begin()
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)

View File

@ -1,7 +1,7 @@
__all__ = [
'print'
]
from logging import INFO
from .logger import logger
@ -22,4 +22,6 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
:return:
"""
line = sep.join(map(str, args))
logger.info(line)
if logger.isEnabledFor(INFO):
kwargs = logger._add_rank_info({})
logger._log(INFO, line, args, **kwargs)

View File

@ -84,7 +84,7 @@ class Metric:
def _sync_get_metric(self, get_metric):
@functools.wraps(get_metric)
def _wrap_get_metric(*args, **kwargs):
assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \
assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \
f"get_metric()."
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
results = get_metric(*args, **kwargs)

View File

@ -366,17 +366,22 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
"""
首先按照 sample 的长度排序然后按照 batch_size*num_batch_per_bucket 为一个桶的大小sample 只会在这个桶内进行组这样
每个 batch 中的 padding 数量会比较少 因为桶内的数据的长度都接近
首先按照 ``sample`` 的长度排序然后按照 batch_size*num_batch_per_bucket 为一个桶的大小``sample`` 只会在这个桶内进行组
这样每个 ``batch`` 中的 ``padding`` 数量会比较少 因为桶内的数据的长度都接近
:param dataset: 实现了 __len__ 方法的数据容器
:param length: 如果为 List应当与 dataset 有一样的长度表示 dataset 中每个元素的数量仅当传入的 dataset fastNLP
DataSet 时支持传入 str会将该str理解为 dataset field 名称 field 中的元素为 int则认为该值是 sample 的长度
如果否则使用 len() 函数得到每个 sample 中这个 field 的长度
:param length: 每条数据的长度
* ``List[int]``
应当与 dataset 有一样的长度表示 dataset 中每个元素的数量
* ``str``
仅当传入的 ``dataset`` :class:`fastNLP.DataSet` 允许传入 `str` `str` 将被认为是 ``dataset`` 中的
``field`` field 中的元素为 ``int``则认为该值是 sample 的长度若不为 ``int`` 则尝试使用 ``len`` 方法
获取该 ``field`` 中每个元素的长度
:param batch_size: 每个 batch 的大小
:param num_batch_per_bucket: 多少个 batch 组成一个桶数据只会在一个桶内进行 shuffle
:param shuffle: 如果为 True将不进行 shuffle实际上数据会以从长到短的方式输出
:param drop_last: 如果最后一个 batch sample 数量无法凑齐 batch_size 这么多是否需要丢掉
:param num_batch_per_bucket: 多少个 ``batch`` 组成一个桶数据只会在一个桶内进行 ``shuffle``
:param shuffle: 如果为 True将不进行 ``shuffle``实际上数据会以从长到短的方式输出
:param drop_last: 如果最后一个 `batch` ``sample`` 数量无法凑齐 ``batch_size`` 这么多是否需要丢掉
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
@ -386,10 +391,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"
types = set(map(type, length))
assert isinstance(length, list) and len(types)==1 and types.pop()==int, \
"When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]"
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
f"`length`({len(length)}) should be equal."
self.dataset = dataset
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。

View File

@ -55,6 +55,7 @@ class ReproducibleSampler:
class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""
随机顺序的 Sampler
:param dataset: 实现了 __len__ 方法的数据容器
:param shuffle: 是否在每次 iterate 的时候打乱顺序
@ -169,9 +170,8 @@ class RandomSampler(ReproducibleSampler):
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def set_distributed(self, num_replicas, rank, pad=True):
def set_distributed(self, num_replicas:int, rank:int, pad:bool=True):
"""
该方法本质上等同于 ddp 情形下的没有完成的初始化应当在初始化该 sampler 本身后立即被调用
:param num_replicas:
:param rank:
@ -215,7 +215,7 @@ class RandomSampler(ReproducibleSampler):
class SequentialSampler(RandomSampler):
def __init__(self, dataset, **kwargs):
"""
按照顺序读取 dataset 在多卡情况下间隔读取例如在两卡情况下0取 [0,2,4,..], 卡1取 [1,3,5...]
按照顺序读取 ``dataset`` 在多卡情况下间隔读取例如在两卡情况下 0 ``[0,2,4,..]``, 卡1取 ``[1,3,5...]``
:param dataset: 实现了 __len__ 方法的数据容器
:param kwargs:
@ -285,13 +285,20 @@ class SequentialSampler(RandomSampler):
class SortedSampler(SequentialSampler):
def __init__(self, dataset, length:Union[str, List], **kwargs):
"""
dataset 中的数据根据 length 从长到短进行迭代在多卡情况下由于padding 最后一个 sample 可能是最长的那个 sample
``dataset`` 中的数据根据 ``length`` 从长到短进行迭代在多卡情况下由于 ``padding`` , 最后一个 ``sample`` 可能是最长
的那个 ``sample``
:param dataset: 实现了 __len__ 方法的数据容器
:param length: 如果为 List应当与 dataset 有一样的长度表示 dataset 中每个元素的数量仅当传入的 dataset fastNLP
DataSet 时支持传入 str会将该str理解为 dataset field 名称 field 中的元素为 int则认为该值是 sample 的长度
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
:param length: 每条数据的长度
* ``List[int]``
应当与 dataset 有一样的长度表示 dataset 中每个元素的数量
* ``str``
仅当传入的 ``dataset`` :class:`fastNLP.DataSet` 允许传入 `str` `str` 将被认为是 ``dataset`` 中的
``field`` field 中的元素为 ``int``则认为该值是 sample 的长度若不为 ``int`` 则尝试使用 ``len`` 方法
获取该 ``field`` 中每个元素的长度
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
super().__init__(dataset=dataset, **kwargs)
if isinstance(dataset, DataSet) and isinstance(length, str):
@ -299,8 +306,9 @@ class SortedSampler(SequentialSampler):
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"
types = set(map(type, length))
assert isinstance(length, list) and len(types)==1 and types.pop()==int, \
"When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]"
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."