mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
commit
d539ba3b14
@ -438,26 +438,29 @@ class EarlyStopCallback(Callback):
|
||||
|
||||
class FitlogCallback(Callback):
|
||||
"""
|
||||
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
|
||||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。
|
||||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
|
||||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
|
||||
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
|
||||
|
||||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
|
||||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。
|
||||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
|
||||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
|
||||
|
||||
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个
|
||||
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过
|
||||
dict的方式传入。如果仅传入DataSet, 则被命名为test
|
||||
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test`
|
||||
:param int verbose: 是否在终端打印内容,0不打印
|
||||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得
|
||||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。
|
||||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。
|
||||
:param bool log_exception: fitlog是否记录发生的exception信息
|
||||
"""
|
||||
# 还没有被导出到 fastNLP 层
|
||||
# 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
|
||||
|
||||
def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
|
||||
|
||||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False):
|
||||
super().__init__()
|
||||
self.datasets = {}
|
||||
self.testers = {}
|
||||
self._log_exception = log_exception
|
||||
assert isinstance(log_loss_every, int) and log_loss_every>=0
|
||||
if tester is not None:
|
||||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
|
||||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
|
||||
@ -477,7 +480,9 @@ class FitlogCallback(Callback):
|
||||
raise TypeError("data receives dict[DataSet] or DataSet object.")
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self._log_loss_every = log_loss_every
|
||||
self._avg_loss = 0
|
||||
|
||||
def on_train_begin(self):
|
||||
if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None:
|
||||
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")
|
||||
@ -490,8 +495,12 @@ class FitlogCallback(Callback):
|
||||
fitlog.add_progress(total_steps=self.n_steps)
|
||||
|
||||
def on_backward_begin(self, loss):
|
||||
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
|
||||
|
||||
if self._log_loss_every>0:
|
||||
self._avg_loss += loss.item()
|
||||
if self.step%self._log_loss_every==0:
|
||||
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch)
|
||||
self._avg_loss = 0
|
||||
|
||||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
|
||||
if better_result:
|
||||
eval_result = deepcopy(eval_result)
|
||||
@ -518,7 +527,7 @@ class FitlogCallback(Callback):
|
||||
def on_exception(self, exception):
|
||||
fitlog.finish(status=1)
|
||||
if self._log_exception:
|
||||
fitlog.add_other(str(exception), name='except_info')
|
||||
fitlog.add_other(repr(exception), name='except_info')
|
||||
|
||||
|
||||
class LRScheduler(Callback):
|
||||
|
@ -516,7 +516,7 @@ class EngChar2DPadder(Padder):
|
||||
))
|
||||
self._exactly_three_dims(contents, field_name)
|
||||
if self.pad_length < 1:
|
||||
max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents]))
|
||||
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
|
||||
else:
|
||||
max_char_length = self.pad_length
|
||||
max_sent_length = max(len(word_lst) for word_lst in contents)
|
||||
|
@ -476,8 +476,8 @@ class SpanFPreRecMetric(MetricBase):
|
||||
label的f1, pre, rec
|
||||
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro':
|
||||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
|
||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
|
||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
|
||||
:param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`.
|
||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
|
||||
"""
|
||||
|
||||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
|
||||
@ -708,8 +708,8 @@ class SQuADMetric(MetricBase):
|
||||
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2`
|
||||
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1`
|
||||
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2`
|
||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
|
||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
|
||||
:param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`.
|
||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
|
||||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。
|
||||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出
|
||||
|
||||
|
@ -494,12 +494,14 @@ class Trainer(object):
|
||||
self.callback_manager = CallbackManager(env={"trainer": self},
|
||||
callbacks=callbacks)
|
||||
|
||||
def train(self, load_best_model=True):
|
||||
def train(self, load_best_model=True, on_exception='ignore'):
|
||||
"""
|
||||
使用该函数使Trainer开始训练。
|
||||
|
||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,
|
||||
如果True, trainer将在返回之前重新加载dev表现最好的模型参数。
|
||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现
|
||||
最好的模型参数。
|
||||
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
|
||||
支持'ignore'与'raise': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出。
|
||||
:return dict: 返回一个字典类型的数据,
|
||||
内含以下内容::
|
||||
|
||||
@ -528,8 +530,10 @@ class Trainer(object):
|
||||
self.callback_manager.on_train_begin()
|
||||
self._train()
|
||||
self.callback_manager.on_train_end()
|
||||
except (CallbackException, KeyboardInterrupt) as e:
|
||||
except (CallbackException, KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_exception(e)
|
||||
if on_exception=='raise':
|
||||
raise e
|
||||
|
||||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
|
||||
print(
|
||||
|
@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户
|
||||
"""
|
||||
__all__ = [
|
||||
"cache_results",
|
||||
"seq_len_to_mask"
|
||||
"seq_len_to_mask",
|
||||
"Example",
|
||||
]
|
||||
|
||||
import _pickle
|
||||
@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
|
||||
'varargs'])
|
||||
|
||||
|
||||
class Example(dict):
|
||||
"""a dict can treat keys as attributes"""
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self.__getitem__(item)
|
||||
except KeyError:
|
||||
raise AttributeError(item)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('__') and key.endswith('__'):
|
||||
raise AttributeError(key)
|
||||
self.__setitem__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
try:
|
||||
self.pop(item)
|
||||
except KeyError:
|
||||
raise AttributeError(item)
|
||||
|
||||
def __getstate__(self):
|
||||
return self
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.update(state)
|
||||
|
||||
|
||||
def _prepare_cache_filepath(filepath):
|
||||
"""
|
||||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径
|
||||
|
@ -1,11 +1,26 @@
|
||||
__all__ = [
|
||||
"Vocabulary"
|
||||
"Vocabulary",
|
||||
"VocabularyOption",
|
||||
]
|
||||
|
||||
from functools import wraps
|
||||
from collections import Counter
|
||||
|
||||
from .dataset import DataSet
|
||||
from .utils import Example
|
||||
|
||||
|
||||
class VocabularyOption(Example):
|
||||
def __init__(self,
|
||||
max_size=None,
|
||||
min_freq=None,
|
||||
padding='<pad>',
|
||||
unknown='<unk>'):
|
||||
super().__init__(
|
||||
max_size=max_size,
|
||||
min_freq=min_freq,
|
||||
padding=padding,
|
||||
unknown=unknown
|
||||
)
|
||||
|
||||
|
||||
def _check_build_vocab(func):
|
||||
|
@ -1,10 +1,14 @@
|
||||
__all__ = [
|
||||
"BaseLoader"
|
||||
"BaseLoader",
|
||||
'DataInfo',
|
||||
'DataSetLoader',
|
||||
]
|
||||
|
||||
import _pickle as pickle
|
||||
import os
|
||||
|
||||
from typing import Union, Dict
|
||||
import os
|
||||
from ..core.dataset import DataSet
|
||||
|
||||
class BaseLoader(object):
|
||||
"""
|
||||
@ -51,24 +55,161 @@ class BaseLoader(object):
|
||||
return obj
|
||||
|
||||
|
||||
class DataLoaderRegister:
|
||||
_readers = {}
|
||||
|
||||
@classmethod
|
||||
def set_reader(cls, reader_cls, read_fn_name):
|
||||
# def wrapper(reader_cls):
|
||||
if read_fn_name in cls._readers:
|
||||
raise KeyError(
|
||||
'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls,
|
||||
read_fn_name))
|
||||
if hasattr(reader_cls, 'load'):
|
||||
cls._readers[read_fn_name] = reader_cls().load
|
||||
return reader_cls
|
||||
|
||||
@classmethod
|
||||
def get_reader(cls, read_fn_name):
|
||||
if read_fn_name in cls._readers:
|
||||
return cls._readers[read_fn_name]
|
||||
raise AttributeError('no read function: {}'.format(read_fn_name))
|
||||
|
||||
# TODO 这个类使用在何处?
|
||||
|
||||
|
||||
def _download_from_url(url, path):
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except:
|
||||
from ..core.utils import _pseudo_tqdm as tqdm
|
||||
import requests
|
||||
|
||||
"""Download file"""
|
||||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
|
||||
chunk_size = 16 * 1024
|
||||
total_size = int(r.headers.get('Content-length', 0))
|
||||
with open(path, "wb") as file, \
|
||||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
|
||||
for chunk in r.iter_content(chunk_size):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
t.update(len(chunk))
|
||||
|
||||
|
||||
def _uncompress(src, dst):
|
||||
import zipfile
|
||||
import gzip
|
||||
import tarfile
|
||||
import os
|
||||
|
||||
def unzip(src, dst):
|
||||
with zipfile.ZipFile(src, 'r') as f:
|
||||
f.extractall(dst)
|
||||
|
||||
def ungz(src, dst):
|
||||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
|
||||
length = 16 * 1024 # 16KB
|
||||
buf = f.read(length)
|
||||
while buf:
|
||||
uf.write(buf)
|
||||
buf = f.read(length)
|
||||
|
||||
def untar(src, dst):
|
||||
with tarfile.open(src, 'r:gz') as f:
|
||||
f.extractall(dst)
|
||||
|
||||
fn, ext = os.path.splitext(src)
|
||||
_, ext_2 = os.path.splitext(fn)
|
||||
if ext == '.zip':
|
||||
unzip(src, dst)
|
||||
elif ext == '.gz' and ext_2 != '.tar':
|
||||
ungz(src, dst)
|
||||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
|
||||
untar(src, dst)
|
||||
else:
|
||||
raise ValueError('unsupported file {}'.format(src))
|
||||
|
||||
|
||||
class DataInfo:
|
||||
"""
|
||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
|
||||
|
||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
|
||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
|
||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
|
||||
"""
|
||||
|
||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
|
||||
self.vocabs = vocabs or {}
|
||||
self.embeddings = embeddings or {}
|
||||
self.datasets = datasets or {}
|
||||
|
||||
|
||||
class DataSetLoader:
|
||||
"""
|
||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
|
||||
|
||||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
|
||||
|
||||
开发者至少应该编写如下内容:
|
||||
|
||||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
|
||||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
|
||||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
|
||||
|
||||
**process 函数中可以 调用load 函数或 _load 函数**
|
||||
|
||||
"""
|
||||
URL = ''
|
||||
DATA_DIR = ''
|
||||
|
||||
ROOT_DIR = '.fastnlp/datasets/'
|
||||
UNCOMPRESS = True
|
||||
|
||||
def _download(self, url: str, pdir: str, uncompress=True) -> str:
|
||||
"""
|
||||
|
||||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
|
||||
|
||||
:param url: 下载的网站
|
||||
:param pdir: 下载到的目录
|
||||
:param uncompress: 是否自动解压缩
|
||||
:return: 数据的存放路径
|
||||
"""
|
||||
fn = os.path.basename(url)
|
||||
path = os.path.join(pdir, fn)
|
||||
"""check data exists"""
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(pdir, exist_ok=True)
|
||||
_download_from_url(url, path)
|
||||
if uncompress:
|
||||
dst = os.path.join(pdir, 'data')
|
||||
if not os.path.exists(dst):
|
||||
_uncompress(path, dst)
|
||||
return dst
|
||||
return path
|
||||
|
||||
def download(self):
|
||||
return self._download(
|
||||
self.URL,
|
||||
os.path.join(self.ROOT_DIR, self.DATA_DIR),
|
||||
uncompress=self.UNCOMPRESS)
|
||||
|
||||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
|
||||
"""
|
||||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
|
||||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。
|
||||
|
||||
:param Union[str, Dict[str, str]] paths: 文件路径
|
||||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
|
||||
"""
|
||||
if isinstance(paths, str):
|
||||
return self._load(paths)
|
||||
return {name: self._load(path) for name, path in paths.items()}
|
||||
|
||||
def _load(self, path: str) -> DataSet:
|
||||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
|
||||
|
||||
:param str path: 文件路径
|
||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
|
||||
"""
|
||||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
|
||||
|
||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
|
||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。
|
||||
|
||||
返回的 :class:`DataInfo` 对象有如下属性:
|
||||
|
||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表
|
||||
- embeddings: (可选) 数据集对应的词嵌入
|
||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
|
||||
|
||||
:param paths: 原始数据读取的路径
|
||||
:param options: 根据不同的任务和数据集,设计自己的参数
|
||||
:return: 返回一个 DataInfo
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
95
fastNLP/io/data_loader/sst.py
Normal file
95
fastNLP/io/data_loader/sst.py
Normal file
@ -0,0 +1,95 @@
|
||||
from typing import Iterable
|
||||
from nltk import Tree
|
||||
from ..base_loader import DataInfo, DataSetLoader
|
||||
from ...core.vocabulary import VocabularyOption, Vocabulary
|
||||
from ...core.dataset import DataSet
|
||||
from ...core.instance import Instance
|
||||
from ..embed_loader import EmbeddingOption, EmbedLoader
|
||||
|
||||
|
||||
class SSTLoader(DataSetLoader):
|
||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
|
||||
DATA_DIR = 'sst/'
|
||||
|
||||
"""
|
||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
|
||||
|
||||
读取SST数据集, DataSet包含fields::
|
||||
|
||||
words: list(str) 需要分类的文本
|
||||
target: str 文本的标签
|
||||
|
||||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
|
||||
|
||||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
|
||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
|
||||
"""
|
||||
|
||||
def __init__(self, subtree=False, fine_grained=False):
|
||||
self.subtree = subtree
|
||||
|
||||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
|
||||
'3': 'positive', '4': 'very positive'}
|
||||
if not fine_grained:
|
||||
tag_v['0'] = tag_v['1']
|
||||
tag_v['4'] = tag_v['3']
|
||||
self.tag_v = tag_v
|
||||
|
||||
def _load(self, path):
|
||||
"""
|
||||
|
||||
:param str path: 存储数据的路径
|
||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
|
||||
"""
|
||||
datalist = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
datas = []
|
||||
for l in f:
|
||||
datas.extend([(s, self.tag_v[t])
|
||||
for s, t in self._get_one(l, self.subtree)])
|
||||
ds = DataSet()
|
||||
for words, tag in datas:
|
||||
ds.append(Instance(words=words, target=tag))
|
||||
return ds
|
||||
|
||||
@staticmethod
|
||||
def _get_one(data, subtree):
|
||||
tree = Tree.fromstring(data)
|
||||
if subtree:
|
||||
return [(t.leaves(), t.label()) for t in tree.subtrees()]
|
||||
return [(tree.leaves(), tree.label())]
|
||||
|
||||
def process(self,
|
||||
paths,
|
||||
train_ds: Iterable[str] = None,
|
||||
src_vocab_op: VocabularyOption = None,
|
||||
tgt_vocab_op: VocabularyOption = None,
|
||||
src_embed_op: EmbeddingOption = None):
|
||||
input_name, target_name = 'words', 'target'
|
||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
|
||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \
|
||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
|
||||
|
||||
info = DataInfo(datasets=self.load(paths))
|
||||
_train_ds = [info.datasets[name]
|
||||
for name in train_ds] if train_ds else info.datasets.values()
|
||||
src_vocab.from_dataset(*_train_ds, field_name=input_name)
|
||||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
|
||||
src_vocab.index_dataset(
|
||||
*info.datasets.values(),
|
||||
field_name=input_name, new_field_name=input_name)
|
||||
tgt_vocab.index_dataset(
|
||||
*info.datasets.values(),
|
||||
field_name=target_name, new_field_name=target_name)
|
||||
info.vocabs = {
|
||||
input_name: src_vocab,
|
||||
target_name: tgt_vocab
|
||||
}
|
||||
|
||||
if src_embed_op is not None:
|
||||
src_embed_op.vocab = src_vocab
|
||||
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
|
||||
info.embeddings[input_name] = init_emb
|
||||
|
||||
return info
|
||||
|
@ -13,8 +13,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的
|
||||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
|
||||
"""
|
||||
__all__ = [
|
||||
'DataInfo',
|
||||
'DataSetLoader',
|
||||
'CSVLoader',
|
||||
'JsonLoader',
|
||||
'ConllLoader',
|
||||
@ -24,158 +22,12 @@ __all__ = [
|
||||
'Conll2003Loader',
|
||||
]
|
||||
|
||||
from nltk.tree import Tree
|
||||
|
||||
from nltk import Tree
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.instance import Instance
|
||||
from .file_reader import _read_csv, _read_json, _read_conll
|
||||
from typing import Union, Dict
|
||||
import os
|
||||
|
||||
|
||||
def _download_from_url(url, path):
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except:
|
||||
from ..core.utils import _pseudo_tqdm as tqdm
|
||||
import requests
|
||||
|
||||
"""Download file"""
|
||||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
|
||||
chunk_size = 16 * 1024
|
||||
total_size = int(r.headers.get('Content-length', 0))
|
||||
with open(path, "wb") as file, \
|
||||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
|
||||
for chunk in r.iter_content(chunk_size):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
t.update(len(chunk))
|
||||
return
|
||||
|
||||
|
||||
def _uncompress(src, dst):
|
||||
import zipfile
|
||||
import gzip
|
||||
import tarfile
|
||||
import os
|
||||
|
||||
def unzip(src, dst):
|
||||
with zipfile.ZipFile(src, 'r') as f:
|
||||
f.extractall(dst)
|
||||
|
||||
def ungz(src, dst):
|
||||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
|
||||
length = 16 * 1024 # 16KB
|
||||
buf = f.read(length)
|
||||
while buf:
|
||||
uf.write(buf)
|
||||
buf = f.read(length)
|
||||
|
||||
def untar(src, dst):
|
||||
with tarfile.open(src, 'r:gz') as f:
|
||||
f.extractall(dst)
|
||||
|
||||
fn, ext = os.path.splitext(src)
|
||||
_, ext_2 = os.path.splitext(fn)
|
||||
if ext == '.zip':
|
||||
unzip(src, dst)
|
||||
elif ext == '.gz' and ext_2 != '.tar':
|
||||
ungz(src, dst)
|
||||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
|
||||
untar(src, dst)
|
||||
else:
|
||||
raise ValueError('unsupported file {}'.format(src))
|
||||
|
||||
|
||||
class DataInfo:
|
||||
"""
|
||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
|
||||
|
||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
|
||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
|
||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
|
||||
"""
|
||||
|
||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
|
||||
self.vocabs = vocabs or {}
|
||||
self.embeddings = embeddings or {}
|
||||
self.datasets = datasets or {}
|
||||
|
||||
|
||||
class DataSetLoader:
|
||||
"""
|
||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
|
||||
|
||||
定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
|
||||
|
||||
开发者至少应该编写如下内容:
|
||||
|
||||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
|
||||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
|
||||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
|
||||
|
||||
**process 函数中可以 调用load 函数或 _load 函数**
|
||||
|
||||
"""
|
||||
|
||||
def _download(self, url: str, path: str, uncompress=True) -> str:
|
||||
"""
|
||||
|
||||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
|
||||
|
||||
:param url: 下载的网站
|
||||
:param path: 下载到的目录
|
||||
:param uncompress: 是否自动解压缩
|
||||
:return: 数据的存放路径
|
||||
"""
|
||||
pdir = os.path.dirname(path)
|
||||
os.makedirs(pdir, exist_ok=True)
|
||||
_download_from_url(url, path)
|
||||
if uncompress:
|
||||
dst = os.path.join(pdir, 'data')
|
||||
_uncompress(path, dst)
|
||||
return dst
|
||||
return path
|
||||
|
||||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
|
||||
"""
|
||||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
|
||||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。
|
||||
|
||||
:param Union[str, Dict[str, str]] paths: 文件路径
|
||||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
|
||||
"""
|
||||
if isinstance(paths, str):
|
||||
return self._load(paths)
|
||||
return {name: self._load(path) for name, path in paths.items()}
|
||||
|
||||
def _load(self, path: str) -> DataSet:
|
||||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
|
||||
|
||||
:param str path: 文件路径
|
||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
|
||||
"""
|
||||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
|
||||
|
||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
|
||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。
|
||||
|
||||
返回的 :class:`DataInfo` 对象有如下属性:
|
||||
|
||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表
|
||||
- embeddings: (可选) 数据集对应的词嵌入
|
||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
|
||||
|
||||
:param paths: 原始数据读取的路径
|
||||
:param options: 根据不同的任务和数据集,设计自己的参数
|
||||
:return: 返回一个 DataInfo
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
from .base_loader import DataSetLoader
|
||||
from .data_loader.sst import SSTLoader
|
||||
|
||||
class PeopleDailyCorpusLoader(DataSetLoader):
|
||||
"""
|
||||
@ -183,12 +35,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):
|
||||
|
||||
读取人民日报数据集
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, pos=True, ner=True):
|
||||
super(PeopleDailyCorpusLoader, self).__init__()
|
||||
self.pos = pos
|
||||
self.ner = ner
|
||||
|
||||
|
||||
def _load(self, data_path):
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
sents = f.readlines()
|
||||
@ -233,7 +85,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
|
||||
example.append(sent_ner)
|
||||
examples.append(example)
|
||||
return self.convert(examples)
|
||||
|
||||
|
||||
def convert(self, data):
|
||||
"""
|
||||
|
||||
@ -284,7 +136,7 @@ class ConllLoader(DataSetLoader):
|
||||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
|
||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, headers, indexes=None, dropna=False):
|
||||
super(ConllLoader, self).__init__()
|
||||
if not isinstance(headers, (list, tuple)):
|
||||
@ -298,7 +150,7 @@ class ConllLoader(DataSetLoader):
|
||||
if len(indexes) != len(headers):
|
||||
raise ValueError
|
||||
self.indexes = indexes
|
||||
|
||||
|
||||
def _load(self, path):
|
||||
ds = DataSet()
|
||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
|
||||
@ -316,7 +168,7 @@ class Conll2003Loader(ConllLoader):
|
||||
关于数据集的更多信息,参考:
|
||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
headers = [
|
||||
'tokens', 'pos', 'chunks', 'ner',
|
||||
@ -354,56 +206,6 @@ def _cut_long_sentence(sent, max_sample_length=200):
|
||||
return cutted_sentence
|
||||
|
||||
|
||||
class SSTLoader(DataSetLoader):
|
||||
"""
|
||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
|
||||
|
||||
读取SST数据集, DataSet包含fields::
|
||||
|
||||
words: list(str) 需要分类的文本
|
||||
target: str 文本的标签
|
||||
|
||||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
|
||||
|
||||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
|
||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
|
||||
"""
|
||||
|
||||
def __init__(self, subtree=False, fine_grained=False):
|
||||
self.subtree = subtree
|
||||
|
||||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
|
||||
'3': 'positive', '4': 'very positive'}
|
||||
if not fine_grained:
|
||||
tag_v['0'] = tag_v['1']
|
||||
tag_v['4'] = tag_v['3']
|
||||
self.tag_v = tag_v
|
||||
|
||||
def _load(self, path):
|
||||
"""
|
||||
|
||||
:param str path: 存储数据的路径
|
||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
|
||||
"""
|
||||
datalist = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
datas = []
|
||||
for l in f:
|
||||
datas.extend([(s, self.tag_v[t])
|
||||
for s, t in self._get_one(l, self.subtree)])
|
||||
ds = DataSet()
|
||||
for words, tag in datas:
|
||||
ds.append(Instance(words=words, target=tag))
|
||||
return ds
|
||||
|
||||
@staticmethod
|
||||
def _get_one(data, subtree):
|
||||
tree = Tree.fromstring(data)
|
||||
if subtree:
|
||||
return [(t.leaves(), t.label()) for t in tree.subtrees()]
|
||||
return [(tree.leaves(), tree.label())]
|
||||
|
||||
|
||||
class JsonLoader(DataSetLoader):
|
||||
"""
|
||||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`
|
||||
@ -417,7 +219,7 @@ class JsonLoader(DataSetLoader):
|
||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
|
||||
Default: ``False``
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, fields=None, dropna=False):
|
||||
super(JsonLoader, self).__init__()
|
||||
self.dropna = dropna
|
||||
@ -428,7 +230,7 @@ class JsonLoader(DataSetLoader):
|
||||
for k, v in fields.items():
|
||||
self.fields[k] = k if v is None else v
|
||||
self.fields_list = list(self.fields.keys())
|
||||
|
||||
|
||||
def _load(self, path):
|
||||
ds = DataSet()
|
||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
|
||||
@ -452,7 +254,7 @@ class SNLILoader(JsonLoader):
|
||||
|
||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
fields = {
|
||||
'sentence1_parse': 'words1',
|
||||
@ -460,14 +262,14 @@ class SNLILoader(JsonLoader):
|
||||
'gold_label': 'target',
|
||||
}
|
||||
super(SNLILoader, self).__init__(fields=fields)
|
||||
|
||||
|
||||
def _load(self, path):
|
||||
ds = super(SNLILoader, self)._load(path)
|
||||
|
||||
|
||||
def parse_tree(x):
|
||||
t = Tree.fromstring(x)
|
||||
return t.leaves()
|
||||
|
||||
|
||||
ds.apply(lambda ins: parse_tree(
|
||||
ins['words1']), new_field_name='words1')
|
||||
ds.apply(lambda ins: parse_tree(
|
||||
@ -488,12 +290,12 @@ class CSVLoader(DataSetLoader):
|
||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
|
||||
Default: ``False``
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, headers=None, sep=",", dropna=False):
|
||||
self.headers = headers
|
||||
self.sep = sep
|
||||
self.dropna = dropna
|
||||
|
||||
|
||||
def _load(self, path):
|
||||
ds = DataSet()
|
||||
for idx, data in _read_csv(path, headers=self.headers,
|
||||
@ -508,7 +310,7 @@ def _add_seg_tag(data):
|
||||
:param data: list of ([word], [pos], [heads], [head_tags])
|
||||
:return: list of ([word], [pos])
|
||||
"""
|
||||
|
||||
|
||||
_processed = []
|
||||
for word_list, pos_list, _, _ in data:
|
||||
new_sample = []
|
||||
|
@ -1,5 +1,6 @@
|
||||
__all__ = [
|
||||
"EmbedLoader"
|
||||
"EmbedLoader",
|
||||
"EmbeddingOption",
|
||||
]
|
||||
|
||||
import os
|
||||
@ -9,8 +10,22 @@ import numpy as np
|
||||
|
||||
from ..core.vocabulary import Vocabulary
|
||||
from .base_loader import BaseLoader
|
||||
from ..core.utils import Example
|
||||
|
||||
|
||||
class EmbeddingOption(Example):
|
||||
def __init__(self,
|
||||
embed_filepath=None,
|
||||
dtype=np.float32,
|
||||
normalize=True,
|
||||
error='ignore'):
|
||||
super().__init__(
|
||||
embed_filepath=embed_filepath,
|
||||
dtype=dtype,
|
||||
normalize=normalize,
|
||||
error=error
|
||||
)
|
||||
|
||||
class EmbedLoader(BaseLoader):
|
||||
"""
|
||||
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader`
|
||||
|
@ -10,6 +10,35 @@ from ..core.const import Const
|
||||
from ..modules.encoder import BertModel
|
||||
|
||||
|
||||
class BertConfig:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class BertForSequenceClassification(BaseModel):
|
||||
"""BERT model for classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel):
|
||||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
num_labels = 2
|
||||
model = BertForSequenceClassification(config, num_labels)
|
||||
model = BertForSequenceClassification(num_labels, config)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels, bert_dir):
|
||||
def __init__(self, num_labels, config=None, bert_dir=None):
|
||||
super(BertForSequenceClassification, self).__init__()
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
if bert_dir is not None:
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
else:
|
||||
if config is None:
|
||||
config = BertConfig()
|
||||
self.bert = BertModel(**config.__dict__)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
|
||||
@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel):
|
||||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
num_choices = 2
|
||||
model = BertForMultipleChoice(config, num_choices, bert_dir)
|
||||
model = BertForMultipleChoice(num_choices, config, bert_dir)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_choices, bert_dir):
|
||||
def __init__(self, num_choices, config=None, bert_dir=None):
|
||||
super(BertForMultipleChoice, self).__init__()
|
||||
self.num_choices = num_choices
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
if bert_dir is not None:
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
else:
|
||||
if config is None:
|
||||
config = BertConfig()
|
||||
self.bert = BertModel(**config.__dict__)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel):
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
num_labels = 2
|
||||
bert_dir = 'your-bert-file-dir'
|
||||
model = BertForTokenClassification(config, num_labels, bert_dir)
|
||||
model = BertForTokenClassification(num_labels, config, bert_dir)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels, bert_dir):
|
||||
def __init__(self, num_labels, config=None, bert_dir=None):
|
||||
super(BertForTokenClassification, self).__init__()
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
if bert_dir is not None:
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
else:
|
||||
if config is None:
|
||||
config = BertConfig()
|
||||
self.bert = BertModel(**config.__dict__)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
|
||||
@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel):
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, bert_dir):
|
||||
def __init__(self, config=None, bert_dir=None):
|
||||
super(BertForQuestionAnswering, self).__init__()
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
if bert_dir is not None:
|
||||
self.bert = BertModel.from_pretrained(bert_dir)
|
||||
else:
|
||||
if config is None:
|
||||
config = BertConfig()
|
||||
self.bert = BertModel(**config.__dict__)
|
||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
|
@ -2,20 +2,64 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.models.bert import BertModel
|
||||
from fastNLP.models.bert import *
|
||||
|
||||
|
||||
class TestBert(unittest.TestCase):
|
||||
def test_bert_1(self):
|
||||
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese")
|
||||
model = BertModel(vocab_size=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
model = BertForSequenceClassification(2)
|
||||
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
for layer in all_encoder_layers:
|
||||
self.assertEqual(tuple(layer.shape), (2, 3, 768))
|
||||
self.assertEqual(tuple(pooled_output.shape), (2, 768))
|
||||
pred = model(input_ids, token_type_ids, input_mask)
|
||||
self.assertTrue(isinstance(pred, dict))
|
||||
self.assertTrue(Const.OUTPUT in pred)
|
||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
|
||||
|
||||
def test_bert_2(self):
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
model = BertForMultipleChoice(2)
|
||||
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
pred = model(input_ids, token_type_ids, input_mask)
|
||||
self.assertTrue(isinstance(pred, dict))
|
||||
self.assertTrue(Const.OUTPUT in pred)
|
||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
|
||||
|
||||
def test_bert_3(self):
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
model = BertForTokenClassification(7)
|
||||
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
pred = model(input_ids, token_type_ids, input_mask)
|
||||
self.assertTrue(isinstance(pred, dict))
|
||||
self.assertTrue(Const.OUTPUT in pred)
|
||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
|
||||
|
||||
def test_bert_4(self):
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
model = BertForQuestionAnswering()
|
||||
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
pred = model(input_ids, token_type_ids, input_mask)
|
||||
self.assertTrue(isinstance(pred, dict))
|
||||
self.assertTrue(Const.OUTPUTS(0) in pred)
|
||||
self.assertTrue(Const.OUTPUTS(1) in pred)
|
||||
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
|
||||
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))
|
||||
|
21
test/modules/encoder/test_bert.py
Normal file
21
test/modules/encoder/test_bert.py
Normal file
@ -0,0 +1,21 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.models.bert import BertModel
|
||||
|
||||
|
||||
class TestBert(unittest.TestCase):
|
||||
def test_bert_1(self):
|
||||
model = BertModel(vocab_size=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
for layer in all_encoder_layers:
|
||||
self.assertEqual(tuple(layer.shape), (2, 3, 768))
|
||||
self.assertEqual(tuple(pooled_output.shape), (2, 768))
|
Loading…
Reference in New Issue
Block a user