diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 261007ae..b079f69f 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -176,9 +176,9 @@ class BertWordPieceEncoder(nn.Module): def index_datasets(self, *datasets, field_name, add_cls_sep=True): """ 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 - bert的pad value。 + bert的pad value。 - :param DataSet datasets: DataSet对象 + :param ~fastNLP.DataSet datasets: DataSet对象 :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 :param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。 :return: diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 01232627..429a8406 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -128,7 +128,7 @@ class DataBundle: """ 向DataBunlde中增加vocab - :param Vocabulary vocab: 词表 + :param ~fastNLP.Vocabulary vocab: 词表 :param str field_name: 这个vocab对应的field名称 :return: """ @@ -138,7 +138,7 @@ class DataBundle: def set_dataset(self, dataset, name): """ - :param DataSet dataset: 传递给DataBundle的DataSet + :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet :param str name: dataset的名称 :return: """ diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index b465ed9b..43fe2ab1 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -84,6 +84,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path: 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, (1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir (2)如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} + 如果有该文件,就直接返回路径 如果没有该文件,则尝试用传入的url下载 @@ -126,8 +127,10 @@ def get_filepath(filepath): 如果filepath为文件夹, 如果内含多个文件, 返回filepath 如果只有一个文件, 返回filepath + filename + 如果filepath为文件 返回filepath + :param str filepath: 路径 :return: """ @@ -237,7 +240,8 @@ def split_filename_suffix(filepath): def get_from_cache(url: str, cache_dir: Path = None) -> Path: """ 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 - 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 """ cache_dir.mkdir(parents=True, exist_ok=True) diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index 4905a34f..8c0d391c 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -1,5 +1,5 @@ """ -Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的 +Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field ; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法 支持以下三种类型的参数:: diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 1b111e40..429b6552 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -257,7 +257,7 @@ class SSTPipe(_CLSPipe): "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." "..." - :param DataBundle data_bundle: 需要处理的DataBundle对象 + :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 :return: """ # 先取出subtree @@ -407,7 +407,7 @@ class IMDBPipe(_CLSPipe): :param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str, target列应该为str。 - :return:DataBundle + :return: DataBundle """ # 替换
def replace_br(raw_words): diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py index a4ca2954..76b32b0a 100644 --- a/fastNLP/io/utils.py +++ b/fastNLP/io/utils.py @@ -6,12 +6,14 @@ from pathlib import Path def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: """ - 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 - { - 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 - 'test': 'xxx' # 可能有,也可能没有 - ... - } + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: + + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 21608c5d..ead75711 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -112,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): 根据tensor的形状,生成一个mask :param drop_p: float, 以多大的概率置为0。 - :param tensor:torch.Tensor + :param tensor: torch.Tensor :return: torch.FloatTensor. 与tensor一样的shape """ mask_x = torch.ones_like(tensor)