mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 12:47:35 +08:00
fix some mistakes
This commit is contained in:
parent
cdf8406ec1
commit
ad9d5eba3a
@ -176,9 +176,9 @@ class BertWordPieceEncoder(nn.Module):
|
||||
def index_datasets(self, *datasets, field_name, add_cls_sep=True):
|
||||
"""
|
||||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
|
||||
bert的pad value。
|
||||
bert的pad value。
|
||||
|
||||
:param DataSet datasets: DataSet对象
|
||||
:param ~fastNLP.DataSet datasets: DataSet对象
|
||||
:param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
|
||||
:param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。
|
||||
:return:
|
||||
|
@ -128,7 +128,7 @@ class DataBundle:
|
||||
"""
|
||||
向DataBunlde中增加vocab
|
||||
|
||||
:param Vocabulary vocab: 词表
|
||||
:param ~fastNLP.Vocabulary vocab: 词表
|
||||
:param str field_name: 这个vocab对应的field名称
|
||||
:return:
|
||||
"""
|
||||
@ -138,7 +138,7 @@ class DataBundle:
|
||||
def set_dataset(self, dataset, name):
|
||||
"""
|
||||
|
||||
:param DataSet dataset: 传递给DataBundle的DataSet
|
||||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
|
||||
:param str name: dataset的名称
|
||||
:return:
|
||||
"""
|
||||
|
@ -84,6 +84,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
|
||||
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件,
|
||||
(1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir
|
||||
(2)如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name}
|
||||
|
||||
如果有该文件,就直接返回路径
|
||||
如果没有该文件,则尝试用传入的url下载
|
||||
|
||||
@ -126,8 +127,10 @@ def get_filepath(filepath):
|
||||
如果filepath为文件夹,
|
||||
如果内含多个文件, 返回filepath
|
||||
如果只有一个文件, 返回filepath + filename
|
||||
|
||||
如果filepath为文件
|
||||
返回filepath
|
||||
|
||||
:param str filepath: 路径
|
||||
:return:
|
||||
"""
|
||||
@ -237,7 +240,8 @@ def split_filename_suffix(filepath):
|
||||
def get_from_cache(url: str, cache_dir: Path = None) -> Path:
|
||||
"""
|
||||
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的
|
||||
文件解压,将解压后的文件全部放在cache_dir文件夹中。
|
||||
文件解压,将解压后的文件全部放在cache_dir文件夹中。
|
||||
|
||||
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。
|
||||
"""
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的
|
||||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的
|
||||
三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field
|
||||
; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法
|
||||
支持以下三种类型的参数::
|
||||
|
@ -257,7 +257,7 @@ class SSTPipe(_CLSPipe):
|
||||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
|
||||
"..."
|
||||
|
||||
:param DataBundle data_bundle: 需要处理的DataBundle对象
|
||||
:param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象
|
||||
:return:
|
||||
"""
|
||||
# 先取出subtree
|
||||
@ -407,7 +407,7 @@ class IMDBPipe(_CLSPipe):
|
||||
|
||||
:param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str,
|
||||
target列应该为str。
|
||||
:return:DataBundle
|
||||
:return: DataBundle
|
||||
"""
|
||||
# 替换<br />
|
||||
def replace_br(raw_words):
|
||||
|
@ -6,12 +6,14 @@ from pathlib import Path
|
||||
|
||||
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
|
||||
"""
|
||||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
|
||||
{
|
||||
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
|
||||
'test': 'xxx' # 可能有,也可能没有
|
||||
...
|
||||
}
|
||||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::
|
||||
|
||||
{
|
||||
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
|
||||
'test': 'xxx' # 可能有,也可能没有
|
||||
...
|
||||
}
|
||||
|
||||
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。
|
||||
|
||||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
|
||||
|
@ -112,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor):
|
||||
根据tensor的形状,生成一个mask
|
||||
|
||||
:param drop_p: float, 以多大的概率置为0。
|
||||
:param tensor:torch.Tensor
|
||||
:param tensor: torch.Tensor
|
||||
:return: torch.FloatTensor. 与tensor一样的shape
|
||||
"""
|
||||
mask_x = torch.ones_like(tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user