Merge pull request #159 from fastnlp/dev0.4.0

Dev0.4.0
ChenXin 2019-06-04 18:05:35 +08:00 committed by GitHub
commit d539ba3b14
13 changed files with 506 additions and 284 deletions


@ -438,26 +438,29 @@ class EarlyStopCallback(Callback):
class FitlogCallback(Callback):
"""
This callback automatically writes the loss and progress to fitlog; if the Trainer has dev data, the dev results are written to the log as well. It also supports passing in
one (or more) test datasets for evaluation (only usable when the trainer has dev data). After every evaluation on dev, these datasets are evaluated too,
and the results are written to fitlog. The results reported for these datasets follow the best dev result, i.e. if dev reached its best score at epoch 3,
the results recorded in fitlog for these datasets are the ones from epoch 3.
Alias: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
This callback can write the loss and progress to fitlog; if the Trainer has dev data, the dev results are written to the log as well. It also supports passing in
one (or more) test datasets for evaluation (only usable when the trainer has dev data). After every evaluation on dev, these datasets are evaluated too,
and the results are written to fitlog. The results reported for these datasets follow the best dev result, i.e. if dev reached its best score at epoch 3,
the results recorded in fitlog for these datasets are the ones from epoch 3.
:param DataSet,dict(DataSet) data: if a DataSet object is passed in, it is evaluated with the metrics from the Trainer. To pass in several
DataSets, use a dict; the dict keys are passed to fitlog as the names of the corresponding datasets. If tester is not None, data must be
passed as a dict. If only a DataSet is passed, it is named `test`.
:param Tester tester: a Tester object, called in on_valid_end; the DataSet inside the tester is referred to as `test`.
:param int verbose: whether to print content to the terminal; 0 means do not print.
:param int log_loss_every: record the loss every this many steps (the value recorded is the average loss of those batches). For large datasets it is
advisable to use a larger value, otherwise the log file becomes huge. Defaults to 0, i.e. the loss is not recorded.
:param int verbose: whether to print the evaluation results to the terminal; 0 means do not print.
:param bool log_exception: whether fitlog records information about any exception that occurs.
"""
# not yet exported at the fastNLP top level
# Alias: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False):
super().__init__()
self.datasets = {}
self.testers = {}
self._log_exception = log_exception
assert isinstance(log_loss_every, int) and log_loss_every>=0
if tester is not None:
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
@ -477,7 +480,9 @@ class FitlogCallback(Callback):
raise TypeError("data receives dict[DataSet] or DataSet object.")
self.verbose = verbose
self._log_loss_every = log_loss_every
self._avg_loss = 0
def on_train_begin(self):
if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None:
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")
@ -490,8 +495,12 @@ class FitlogCallback(Callback):
fitlog.add_progress(total_steps=self.n_steps)
def on_backward_begin(self, loss):
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
if self._log_loss_every>0:
self._avg_loss += loss.item()
if self.step%self._log_loss_every==0:
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch)
self._avg_loss = 0
def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
if better_result:
eval_result = deepcopy(eval_result)
@ -518,7 +527,7 @@ class FitlogCallback(Callback):
def on_exception(self, exception):
fitlog.finish(status=1)
if self._log_exception:
fitlog.add_other(str(exception), name='except_info')
fitlog.add_other(repr(exception), name='except_info')
class LRScheduler(Callback):
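To make the new log_loss_every behaviour concrete, here is a plain-Python sketch of the averaging that on_backward_begin now performs (printing instead of calling fitlog.add_loss; the loss values are made up):

log_loss_every = 3
avg_loss = 0.0
for step, loss in enumerate([0.9, 0.8, 0.7, 0.6, 0.5, 0.4], start=1):
    if log_loss_every > 0:
        avg_loss += loss
        if step % log_loss_every == 0:
            # only the mean over the last `log_loss_every` steps is recorded
            print(f"step {step}: mean loss = {avg_loss / log_loss_every:.3f}")
            avg_loss = 0.0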


@ -516,7 +516,7 @@ class EngChar2DPadder(Padder):
))
self._exactly_three_dims(contents, field_name)
if self.pad_length < 1:
max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents]))
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
else:
max_char_length = self.pad_length
max_sent_length = max(len(word_lst) for word_lst in contents)
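To see why the expression was changed, here is a tiny standalone example (the contents below are made up) where the old nested max misses the longest word while the new per-sentence max finds it:

contents = [
    [['a', 'b', 'c'], ['a']],                                      # char lengths per word: [3, 1]
    [['a', 'b'], ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']],   # char lengths per word: [2, 9]
]
old = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents]))
new = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
print(old, new)   # -> 3 9: the old form compares the length lists lexicographically and under-estimates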


@ -476,8 +476,8 @@ class SpanFPreRecMetric(MetricBase):
the f1, pre and rec of each label
:param str f_type: 'micro' or 'macro'. 'micro': first accumulate the overall TP, FN and FP counts, then compute f, precision and recall; 'macro':
compute f, precision and recall for each class separately and then average them (each class's f gets the same weight).
:param float beta: the f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are beta=0.5, 1, 2. With 0.5,
precision is weighted more heavily than recall; with 1 they are weighted equally; with 2 recall is weighted more heavily than precision.
:param float beta: the f_beta score, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`.
Common values are beta=0.5, 1, 2. With 0.5, precision is weighted more heavily than recall; with 1 they are weighted equally; with 2 recall is weighted more heavily than precision.
"""
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
@ -708,8 +708,8 @@ class SQuADMetric(MetricBase):
:param pred2: the mapping of `pred2` in the parameter map; None means the mapping is `pred2`->`pred2`
:param target1: the mapping of `target1` in the parameter map; None means the mapping is `target1`->`target1`
:param target2: the mapping of `target2` in the parameter map; None means the mapping is `target2`->`target2`
:param float beta: the f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are beta=0.5, 1, 2. With 0.5,
precision is weighted more heavily than recall; with 1 they are weighted equally; with 2 recall is weighted more heavily than precision.
:param float beta: the f_beta score, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`.
Common values are beta=0.5, 1, 2. With 0.5, precision is weighted more heavily than recall; with 1 they are weighted equally; with 2 recall is weighted more heavily than precision.
:param bool right_open: if True, the start and end pointers mark a half-open interval [start, end); if False, a closed interval [start, end].
:param bool print_predict_stat: if True, print statistics on whether the predicted answer and the gold answer are empty; if False, print nothing.
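For reference, a small standalone helper mirroring the f_beta formula documented above (plain Python, not part of the diff):

def f_beta(pre: float, rec: float, beta: float = 1.0) -> float:
    # f_beta = (1 + beta^2) * (pre * rec) / (beta^2 * pre + rec)
    if pre == 0.0 and rec == 0.0:
        return 0.0
    return (1 + beta ** 2) * pre * rec / (beta ** 2 * pre + rec)

print(f_beta(0.8, 0.5, beta=0.5))  # weights precision more heavily
print(f_beta(0.8, 0.5, beta=1.0))  # the usual F1
print(f_beta(0.8, 0.5, beta=2.0))  # weights recall more heavily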


@ -494,12 +494,14 @@ class Trainer(object):
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)
def train(self, load_best_model=True):
def train(self, load_best_model=True, on_exception='ignore'):
"""
Call this method to make the Trainer start training.
:param bool load_best_model: this parameter only takes effect if dev_data was provided at initialization.
If True, the trainer reloads the parameters of the model that performed best on dev before returning.
:param bool load_best_model: this parameter only takes effect if dev_data was provided at initialization. If True, the trainer reloads the
parameters of the model that performed best on dev before returning.
:param str on_exception: whether an exception raised during training should be re-raised after it has been handled by the on_exception() of :py:class:`Callback`.
Supported values are 'ignore' and 'raise': 'ignore' swallows the exception, so code after Trainer.train() keeps running; 'raise' re-raises the exception.
:return dict: a dict is returned,
containing the following items::
@ -528,8 +530,10 @@ class Trainer(object):
self.callback_manager.on_train_begin()
self._train()
self.callback_manager.on_train_end()
except (CallbackException, KeyboardInterrupt) as e:
except (CallbackException, KeyboardInterrupt, Exception) as e:
self.callback_manager.on_exception(e)
if on_exception=='raise':
raise e
if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
print(
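A plain-Python sketch of the new control flow (no Trainer involved; the RuntimeError stands in for whatever goes wrong inside self._train()):

def run(on_exception='ignore'):
    try:
        raise RuntimeError("something went wrong mid-training")
    except (KeyboardInterrupt, Exception) as e:
        print("on_exception callbacks see:", repr(e))   # callback_manager.on_exception(e)
        if on_exception == 'raise':
            raise e

run(on_exception='ignore')        # exception is swallowed, code after train() keeps running
try:
    run(on_exception='raise')
except RuntimeError as e:
    print("re-raised:", e)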


@ -3,7 +3,8 @@ The utils module implements many of the utilities needed inside and outside fastNLP. Among them, users
"""
__all__ = [
"cache_results",
"seq_len_to_mask"
"seq_len_to_mask",
"Example",
]
import _pickle
@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
'varargs'])
class Example(dict):
"""a dict can treat keys as attributes"""
def __getattr__(self, item):
try:
return self.__getitem__(item)
except KeyError:
raise AttributeError(item)
def __setattr__(self, key, value):
if key.startswith('__') and key.endswith('__'):
raise AttributeError(key)
self.__setitem__(key, value)
def __delattr__(self, item):
try:
self.pop(item)
except KeyError:
raise AttributeError(item)
def __getstate__(self):
return self
def __setstate__(self, state):
self.update(state)
def _prepare_cache_filepath(filepath):
"""
Check whether filepath can serve as a valid cache file path; if so, the directory is created automatically.
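A quick check of the behaviour Example provides (the keys are arbitrary; the import path follows the location of the class added above):

from fastNLP.core.utils import Example

ex = Example(lr=0.1, n_epochs=10)
ex.batch_size = 32              # attribute writes go through __setitem__
assert ex['batch_size'] == 32
assert ex.lr == 0.1             # attribute reads go through __getitem__
try:
    ex.missing
except AttributeError as e:
    print("unknown keys raise AttributeError:", e)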


@ -1,11 +1,26 @@
__all__ = [
"Vocabulary"
"Vocabulary",
"VocabularyOption",
]
from functools import wraps
from collections import Counter
from .dataset import DataSet
from .utils import Example
class VocabularyOption(Example):
def __init__(self,
max_size=None,
min_freq=None,
padding='<pad>',
unknown='<unk>'):
super().__init__(
max_size=max_size,
min_freq=min_freq,
padding=padding,
unknown=unknown
)
def _check_build_vocab(func):
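Because VocabularyOption is an Example (i.e. a dict), it can be unpacked straight into Vocabulary, which is how the SST loader later in this change uses it; a minimal sketch:

from fastNLP.core.vocabulary import Vocabulary, VocabularyOption

opt = VocabularyOption(max_size=10000, min_freq=2)   # padding and unknown keep their defaults
vocab = Vocabulary(**opt)                            # the keys match Vocabulary's keyword arguments
print(opt.max_size, opt.min_freq)                    # attribute access on the option dict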


@ -1,10 +1,14 @@
__all__ = [
"BaseLoader"
"BaseLoader",
'DataInfo',
'DataSetLoader',
]
import _pickle as pickle
import os
from typing import Union, Dict
import os
from ..core.dataset import DataSet
class BaseLoader(object):
"""
@ -51,24 +55,161 @@ class BaseLoader(object):
return obj
class DataLoaderRegister:
_readers = {}
@classmethod
def set_reader(cls, reader_cls, read_fn_name):
# def wrapper(reader_cls):
if read_fn_name in cls._readers:
raise KeyError(
'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls,
read_fn_name))
if hasattr(reader_cls, 'load'):
cls._readers[read_fn_name] = reader_cls().load
return reader_cls
@classmethod
def get_reader(cls, read_fn_name):
if read_fn_name in cls._readers:
return cls._readers[read_fn_name]
raise AttributeError('no read function: {}'.format(read_fn_name))
# TODO where is this class used?
def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
except:
from ..core.utils import _pseudo_tqdm as tqdm
import requests
"""Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file, \
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size):
if chunk:
file.write(chunk)
t.update(len(chunk))
def _uncompress(src, dst):
import zipfile
import gzip
import tarfile
import os
def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst)
def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
buf = f.read(length)
while buf:
uf.write(buf)
buf = f.read(length)
def untar(src, dst):
with tarfile.open(src, 'r:gz') as f:
f.extractall(dst)
fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn)
if ext == '.zip':
unzip(src, dst)
elif ext == '.gz' and ext_2 != '.tar':
ungz(src, dst)
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
untar(src, dst)
else:
raise ValueError('unsupported file {}'.format(src))
class DataInfo:
"""
Information about processed data: a collection of datasets (for example separate train, dev and test sets) together with the vocabularies and embeddings they use.
:param vocabs: a dict mapping names (str) to :class:`~fastNLP.Vocabulary` objects
:param embeddings: a dict mapping names (str) to embeddings; see :class:`~fastNLP.io.EmbedLoader`
:param datasets: a dict mapping names (str) to :class:`~fastNLP.DataSet` objects
"""
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
self.embeddings = embeddings or {}
self.datasets = datasets or {}
class DataSetLoader:
"""
Alias: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
Defines the API that every DataSetLoader must provide; developers should subclass it to implement their own DataSetLoaders.
At a minimum, a developer should implement:
- the _load method: read data from one data file into a :class:`~fastNLP.DataSet`
- the load method (the base-class implementation may be reused): read data from one or more data files into one or more :class:`~fastNLP.DataSet`
- the process method: read data from one or more data files and turn it into one or more :class:`~fastNLP.DataSet` that can be used for training
**The process method may call the load or _load method.**
"""
URL = ''
DATA_DIR = ''
ROOT_DIR = '.fastnlp/datasets/'
UNCOMPRESS = True
def _download(self, url: str, pdir: str, uncompress=True) -> str:
"""
``url`` 下载数据到 ``path`` 如果 ``uncompress`` ``True`` 自动解压
:param url: 下载的网站
:param pdir: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
fn = os.path.basename(url)
path = os.path.join(pdir, fn)
"""check data exists"""
if not os.path.exists(path):
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress:
dst = os.path.join(pdir, 'data')
if not os.path.exists(dst):
_uncompress(path, dst)
return dst
return path
def download(self):
return self._download(
self.URL,
os.path.join(self.ROOT_DIR, self.DATA_DIR),
uncompress=self.UNCOMPRESS)
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
Read data from the file(s) at one or more given paths, and return one or more :class:`~fastNLP.DataSet`.
When handling multiple paths, the keys of the returned dict match the keys of the dict passed in.
:param Union[str, Dict[str, str]] paths: file path(s)
:return: an object of class :class:`~fastNLP.DataSet`, or a dict holding multiple :class:`~fastNLP.DataSet`
"""
if isinstance(paths, str):
return self._load(paths)
return {name: self._load(path) for name, path in paths.items()}
def _load(self, path: str) -> DataSet:
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
raise NotImplementedError
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
"""
For a specific task and dataset, read and process the data, returning a DataInfo object or a dict.
Data is read from the file(s) at one or more paths; the DataInfo object may contain one or more datasets.
When handling multiple paths, the keys of the dict inside the returned DataInfo match the keys of the dict passed in.
The returned :class:`DataInfo` object has the following attributes:
- vocabs: a dict of the vocabularies obtained from the datasets
- embeddings: (optional) the word embeddings corresponding to the datasets
- datasets: a dict containing a number of :class:`~fastNLP.DataSet` objects, whose field names follow :mod:`~fastNLP.core.const`
:param paths: path(s) to the raw data
:param options: task- and dataset-specific parameters
:return: a DataInfo object
"""
raise NotImplementedError
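A minimal, hypothetical loader following the API described above (the tab-separated format, class name and paths are made up; only _load is strictly required, load comes from the base class):

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.base_loader import DataInfo, DataSetLoader

class TabLabelLoader(DataSetLoader):
    """Reads lines of a made-up 'label<TAB>space-separated words' format."""
    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                label, text = line.rstrip('\n').split('\t', 1)
                ds.append(Instance(words=text.split(), target=label))
        return ds

    def process(self, paths, **options) -> DataInfo:
        # the simplest possible process(): just wrap whatever load() returns
        datasets = self.load(paths)
        if isinstance(datasets, DataSet):
            datasets = {'train': datasets}
        return DataInfo(datasets=datasets)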


@ -0,0 +1,95 @@
from typing import Iterable
from nltk import Tree
from ..base_loader import DataInfo, DataSetLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..embed_loader import EmbeddingOption, EmbedLoader
class SSTLoader(DataSetLoader):
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
DATA_DIR = 'sst/'
"""
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
Reads the SST dataset; the DataSet contains the fields::
words: list(str), the text to classify
target: str, the label of the text
Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
:param subtree: whether to expand the data into subtrees to enlarge the dataset. Default: ``False``
:param fine_grained: whether to use the SST-5 scheme; if ``False``, SST-2 is used. Default: ``False``
"""
def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v
def _load(self, path):
"""
:param str path: path to the data file
:return: a :class:`~fastNLP.DataSet` object
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds
@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())]
def process(self,
paths,
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset(
*info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,
target_name: tgt_vocab
}
if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb
return info
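A hedged usage sketch of the new loader (the local file paths are placeholders for a downloaded copy of the SST trees; VocabularyOption is the helper added in core/vocabulary.py):

from fastNLP.io.dataset_loader import SSTLoader
from fastNLP.core.vocabulary import VocabularyOption

loader = SSTLoader(fine_grained=False)            # SST-2 style labels
info = loader.process(
    {'train': 'sst/train.txt', 'dev': 'sst/dev.txt', 'test': 'sst/test.txt'},
    train_ds=['train'],                           # build the vocabularies from the train split only
    src_vocab_op=VocabularyOption(min_freq=2),
)
print(info.datasets.keys(), len(info.vocabs['words']))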


@ -13,8 +13,6 @@ The dataset_loader module implements many DataSetLoaders for reading data in different formats
Developers who contribute DataSetLoaders to fastNLP should refer to the introduction of :class:`~fastNLP.io.DataSetLoader`.
"""
__all__ = [
'DataInfo',
'DataSetLoader',
'CSVLoader',
'JsonLoader',
'ConllLoader',
@ -24,158 +22,12 @@ __all__ = [
'Conll2003Loader',
]
from nltk.tree import Tree
from nltk import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict
import os
def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
except:
from ..core.utils import _pseudo_tqdm as tqdm
import requests
"""Download file"""
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file, \
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size):
if chunk:
file.write(chunk)
t.update(len(chunk))
return
def _uncompress(src, dst):
import zipfile
import gzip
import tarfile
import os
def unzip(src, dst):
with zipfile.ZipFile(src, 'r') as f:
f.extractall(dst)
def ungz(src, dst):
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
length = 16 * 1024 # 16KB
buf = f.read(length)
while buf:
uf.write(buf)
buf = f.read(length)
def untar(src, dst):
with tarfile.open(src, 'r:gz') as f:
f.extractall(dst)
fn, ext = os.path.splitext(src)
_, ext_2 = os.path.splitext(fn)
if ext == '.zip':
unzip(src, dst)
elif ext == '.gz' and ext_2 != '.tar':
ungz(src, dst)
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
untar(src, dst)
else:
raise ValueError('unsupported file {}'.format(src))
class DataInfo:
"""
Information about processed data: a collection of datasets (for example separate train, dev and test sets) together with the vocabularies and embeddings they use.
:param vocabs: a dict mapping names (str) to :class:`~fastNLP.Vocabulary` objects
:param embeddings: a dict mapping names (str) to embeddings; see :class:`~fastNLP.io.EmbedLoader`
:param datasets: a dict mapping names (str) to :class:`~fastNLP.DataSet` objects
"""
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
self.embeddings = embeddings or {}
self.datasets = datasets or {}
class DataSetLoader:
"""
Alias: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
Defines the API that every DataSetLoader (for a specific task on specific data) must provide; developers should subclass it to implement their own DataSetLoaders.
At a minimum, a developer should implement:
- the _load method: read data from one data file into a :class:`~fastNLP.DataSet`
- the load method (the base-class implementation may be reused): read data from one or more data files into one or more :class:`~fastNLP.DataSet`
- the process method: read data from one or more data files and turn it into one or more :class:`~fastNLP.DataSet` that can be used for training
**The process method may call the load or _load method.**
"""
def _download(self, url: str, path: str, uncompress=True) -> str:
"""
``url`` 下载数据到 ``path`` 如果 ``uncompress`` ``True`` 自动解压
:param url: 下载的网站
:param path: 下载到的目录
:param uncompress: 是否自动解压缩
:return: 数据的存放路径
"""
pdir = os.path.dirname(path)
os.makedirs(pdir, exist_ok=True)
_download_from_url(url, path)
if uncompress:
dst = os.path.join(pdir, 'data')
_uncompress(path, dst)
return dst
return path
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
"""
Read data from the file(s) at one or more given paths, and return one or more :class:`~fastNLP.DataSet`.
When handling multiple paths, the keys of the returned dict match the keys of the dict passed in.
:param Union[str, Dict[str, str]] paths: file path(s)
:return: an object of class :class:`~fastNLP.DataSet`, or a dict holding multiple :class:`~fastNLP.DataSet`
"""
if isinstance(paths, str):
return self._load(paths)
return {name: self._load(path) for name, path in paths.items()}
def _load(self, path: str) -> DataSet:
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
:param str path: 文件路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
raise NotImplementedError
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
"""
For a specific task and dataset, read and process the data, returning a DataInfo object or a dict.
Data is read from the file(s) at one or more paths; the DataInfo object may contain one or more datasets.
When handling multiple paths, the keys of the dict inside the returned DataInfo match the keys of the dict passed in.
The returned :class:`DataInfo` object has the following attributes:
- vocabs: a dict of the vocabularies obtained from the datasets
- embeddings: (optional) the word embeddings corresponding to the datasets
- datasets: a dict containing a number of :class:`~fastNLP.DataSet` objects, whose field names follow :mod:`~fastNLP.core.const`
:param paths: path(s) to the raw data
:param options: task- and dataset-specific parameters
:return: a DataInfo object
"""
raise NotImplementedError
from .base_loader import DataSetLoader
from .data_loader.sst import SSTLoader
class PeopleDailyCorpusLoader(DataSetLoader):
"""
@ -183,12 +35,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):
Reads the People's Daily corpus.
"""
def __init__(self, pos=True, ner=True):
super(PeopleDailyCorpusLoader, self).__init__()
self.pos = pos
self.ner = ner
def _load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines()
@ -233,7 +85,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
example.append(sent_ner)
examples.append(example)
return self.convert(examples)
def convert(self, data):
"""
@ -284,7 +136,7 @@ class ConllLoader(DataSetLoader):
:param indexes: indices of the data columns to keep, starting from 0. If ``None``, all columns are kept. Default: ``None``
:param dropna: whether to skip malformed data; if ``False``, a ``ValueError`` is raised when malformed data is encountered. Default: ``False``
"""
def __init__(self, headers, indexes=None, dropna=False):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@ -298,7 +150,7 @@ class ConllLoader(DataSetLoader):
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes
def _load(self, path):
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
@ -316,7 +168,7 @@ class Conll2003Loader(ConllLoader):
For more information about the dataset, see:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""
def __init__(self):
headers = [
'tokens', 'pos', 'chunks', 'ner',
@ -354,56 +206,6 @@ def _cut_long_sentence(sent, max_sample_length=200):
return cutted_sentence
class SSTLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
Reads the SST dataset; the DataSet contains the fields::
words: list(str), the text to classify
target: str, the label of the text
Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
:param subtree: whether to expand the data into subtrees to enlarge the dataset. Default: ``False``
:param fine_grained: whether to use the SST-5 scheme; if ``False``, SST-2 is used. Default: ``False``
"""
def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v
def _load(self, path):
"""
:param str path: path to the data file
:return: a :class:`~fastNLP.DataSet` object
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds
@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())]
class JsonLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`
@ -417,7 +219,7 @@ class JsonLoader(DataSetLoader):
:param bool dropna: whether to skip malformed data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when malformed data is encountered.
Default: ``False``
"""
def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
self.dropna = dropna
@ -428,7 +230,7 @@ class JsonLoader(DataSetLoader):
for k, v in fields.items():
self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys())
def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
@ -452,7 +254,7 @@ class SNLILoader(JsonLoader):
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""
def __init__(self):
fields = {
'sentence1_parse': 'words1',
@ -460,14 +262,14 @@ class SNLILoader(JsonLoader):
'gold_label': 'target',
}
super(SNLILoader, self).__init__(fields=fields)
def _load(self, path):
ds = super(SNLILoader, self)._load(path)
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()
ds.apply(lambda ins: parse_tree(
ins['words1']), new_field_name='words1')
ds.apply(lambda ins: parse_tree(
@ -488,12 +290,12 @@ class CSVLoader(DataSetLoader):
:param bool dropna: whether to skip malformed data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when malformed data is encountered.
Default: ``False``
"""
def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers
self.sep = sep
self.dropna = dropna
def _load(self, path):
ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers,
@ -508,7 +310,7 @@ def _add_seg_tag(data):
:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""
_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []


@ -1,5 +1,6 @@
__all__ = [
"EmbedLoader"
"EmbedLoader",
"EmbeddingOption",
]
import os
@ -9,8 +10,22 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader
from ..core.utils import Example
class EmbeddingOption(Example):
def __init__(self,
embed_filepath=None,
dtype=np.float32,
normalize=True,
error='ignore'):
super().__init__(
embed_filepath=embed_filepath,
dtype=dtype,
normalize=normalize,
error=error
)
class EmbedLoader(BaseLoader):
"""
Alias: :class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`
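Because EmbeddingOption is also an Example, extra keys such as vocab can be attached before unpacking it into EmbedLoader.load_with_vocab, exactly as the SST loader above does; a hedged sketch (the embedding file path is a placeholder):

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader, EmbeddingOption

opt = EmbeddingOption(embed_filepath='glove.6B.100d.txt', normalize=False)
vocab = Vocabulary()
vocab.add_word_lst(['the', 'movie', 'was', 'great'])
opt.vocab = vocab                                # attribute write, stored as a dict key
embedding = EmbedLoader.load_with_vocab(**opt)   # numpy array with one row per vocabulary entry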


@ -10,6 +10,35 @@ from ..core.const import Const
from ..modules.encoder import BertModel
class BertConfig:
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
class BertForSequenceClassification(BaseModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel):
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
model = BertForSequenceClassification(num_labels, config)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels, bert_dir):
def __init__(self, num_labels, config=None, bert_dir=None):
super(BertForSequenceClassification, self).__init__()
self.num_labels = num_labels
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel):
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices, bert_dir)
model = BertForMultipleChoice(num_choices, config, bert_dir)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices, bert_dir):
def __init__(self, num_choices, config=None, bert_dir=None):
super(BertForMultipleChoice, self).__init__()
self.num_choices = num_choices
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel):
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
bert_dir = 'your-bert-file-dir'
model = BertForTokenClassification(config, num_labels, bert_dir)
model = BertForTokenClassification(num_labels, config, bert_dir)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels, bert_dir):
def __init__(self, num_labels, config=None, bert_dir=None):
super(BertForTokenClassification, self).__init__()
self.num_labels = num_labels
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, bert_dir):
def __init__(self, config=None, bert_dir=None):
super(BertForQuestionAnswering, self).__init__()
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
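A hedged usage sketch of the new signatures (the configuration values below are made up and deliberately tiny so the model builds quickly; pass bert_dir instead when pretrained weights are wanted):

import torch
from fastNLP.models.bert import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=128)
model = BertForSequenceClassification(num_labels=3, config=config)   # no bert_dir: randomly initialized

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)
pred = model(input_ids, token_type_ids, input_mask)   # a dict with logits of shape (2, 3)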


@ -2,20 +2,64 @@ import unittest
import torch
from fastNLP.models.bert import BertModel
from fastNLP.models.bert import *
class TestBert(unittest.TestCase):
def test_bert_1(self):
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese")
model = BertModel(vocab_size=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
from fastNLP.core.const import Const
model = BertForSequenceClassification(2)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
self.assertEqual(tuple(pooled_output.shape), (2, 768))
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
def test_bert_2(self):
from fastNLP.core.const import Const
model = BertForMultipleChoice(2)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
def test_bert_3(self):
from fastNLP.core.const import Const
model = BertForTokenClassification(7)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
def test_bert_4(self):
from fastNLP.core.const import Const
model = BertForQuestionAnswering()
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUTS(0) in pred)
self.assertTrue(Const.OUTPUTS(1) in pred)
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))


@ -0,0 +1,21 @@
import unittest
import torch
from fastNLP.models.bert import BertModel
class TestBert(unittest.TestCase):
def test_bert_1(self):
model = BertModel(vocab_size=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
self.assertEqual(tuple(pooled_output.shape), (2, 768))