Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-03 04:37:37 +08:00

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

Commit: 2f8b194c23

@@ -12,20 +12,19 @@ from ...core.instance import Instance
class JsonLoader(Loader):
    """
    Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`

    Reads data in json format. The data must be stored one record per line, where each line is a json object holding the relevant attributes.

    :param dict fields: the json attribute names to read, and the field_name under which each is stored in the DataSet.
        Every `key` of ``fields`` must be an attribute name of the json object; the corresponding `value` is the `field_name` used in the DataSet.
        A `value` may also be ``None``, in which case the resulting `field_name` is the same as the json attribute name.
        ``fields`` itself may be ``None``, in which case all attributes of the json object are kept in the DataSet. Default: ``None``
    :param bool dropna: whether to ignore invalid data. If ``True``, invalid lines are skipped; if ``False``, a ``ValueError`` is raised when invalid data is encountered.
        Default: ``False``
    """

    def __init__(self, fields=None, dropna=False):
        """
        :param dict fields: the json attribute names to read, and the field_name under which each is stored in the DataSet.
            Every `key` of ``fields`` must be an attribute name of the json object; the corresponding `value` is the `field_name` used in the DataSet.
            A `value` may also be ``None``, in which case the resulting `field_name` is the same as the json attribute name.
            ``fields`` itself may be ``None``, in which case all attributes of the json object are kept in the DataSet. Default: ``None``
        :param bool dropna: whether to ignore invalid data. If ``True``, invalid lines are skipped; if ``False``, a ``ValueError`` is raised when invalid data is encountered.
            Default: ``False``
        """
        super(JsonLoader, self).__init__()
        self.dropna = dropna
        self.fields = None
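
For reference, a minimal usage sketch of JsonLoader as documented above. This is an illustration, not part of the diff: the file name "data.jsonl" and the field mapping are assumptions, and it relies on the standard Loader.load() entry point.

# Hypothetical usage of JsonLoader; "data.jsonl" and the field mapping are placeholders.
from fastNLP.io import JsonLoader

# map json attributes to DataSet field names; None keeps the original name
loader = JsonLoader(fields={"text": None, "summary": "abstract"}, dropna=True)
data_bundle = loader.load("data.jsonl")   # one json object per line
print(data_bundle)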
fastNLP/io/pipe/summarization.py (new file, 187 lines)

@@ -0,0 +1,187 @@
"""undocumented"""
|
||||
import numpy as np
|
||||
|
||||
from .pipe import Pipe
|
||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
|
||||
from ..loader.json import JsonLoader
|
||||
from ..data_bundle import DataBundle
|
||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
|
||||
from ...core.const import Const
|
||||
from ...core.dataset import DataSet
|
||||
from ...core.instance import Instance
|
||||
from ...core.vocabulary import Vocabulary
|
||||
|
||||
|
||||
WORD_PAD = "[PAD]"
|
||||
WORD_UNK = "[UNK]"
|
||||
DOMAIN_UNK = "X"
|
||||
TAG_UNK = "X"
|
||||
|
||||
|
||||
|
||||
class ExtCNNDMPipe(Pipe):
|
||||
"""
|
||||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构:
|
||||
|
||||
.. csv-table::
|
||||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
|
||||
|
||||
"""
|
||||
def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
|
||||
"""
|
||||
|
||||
:param vocab_size: int, 词表大小
|
||||
:param vocab_path: str, 外部词表路径
|
||||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断
|
||||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断
|
||||
:param domain: bool, 是否需要建立domain词表
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_path = vocab_path
|
||||
self.sent_max_len = sent_max_len
|
||||
self.doc_max_timesteps = doc_max_timesteps
|
||||
self.domain = domain
|
||||
|
||||
|
||||
def process(self, db: DataBundle):
|
||||
"""
|
||||
传入的DataSet应该具备如下的结构
|
||||
|
||||
.. csv-table::
|
||||
:header: "text", "summary", "label", "publication"
|
||||
|
||||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
|
||||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
|
||||
["..."], ["..."], [], "cnndm"
|
||||
|
||||
:param data_bundle:
|
||||
:return: 处理得到的数据包括
|
||||
.. csv-table::
|
||||
:header: "text_wd", "words", "seq_len", "target"
|
||||
|
||||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0]
|
||||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0]
|
||||
[[""],...,[""]], [[],...,[]], [], []
|
||||
"""
|
||||
|
||||
db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
|
||||
db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
|
||||
db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
|
||||
db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
|
||||
|
||||
db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
|
||||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
|
||||
|
||||
# pad document
|
||||
db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
|
||||
db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
|
||||
db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
|
||||
|
||||
db = _drop_empty_instance(db, "label")
|
||||
|
||||
# set input and target
|
||||
db.set_input(Const.INPUT, Const.INPUT_LEN)
|
||||
db.set_target(Const.TARGET, Const.INPUT_LEN)
|
||||
|
||||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path)
|
||||
word_list = []
|
||||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f:
|
||||
cnt = 2 # pad and unk
|
||||
for line in vocab_f:
|
||||
pieces = line.split("\t")
|
||||
word_list.append(pieces[0])
|
||||
cnt += 1
|
||||
if cnt > self.vocab_size:
|
||||
break
|
||||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
|
||||
vocabs.add_word_lst(word_list)
|
||||
vocabs.build_vocab()
|
||||
db.set_vocab(vocabs, "vocab")
|
||||
|
||||
if self.domain == True:
|
||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
|
||||
domaindict.from_dataset(db.get_dataset("train"), field_name="publication")
|
||||
db.set_vocab(domaindict, "domain")
|
||||
|
||||
return db
|
||||
|
||||
|
||||
def process_from_file(self, paths=None):
|
||||
"""
|
||||
:param paths: dict or string
|
||||
:return: DataBundle
|
||||
"""
|
||||
db = DataBundle()
|
||||
if isinstance(paths, dict):
|
||||
for key, value in paths.items():
|
||||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
|
||||
else:
|
||||
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
|
||||
self.process(db)
|
||||
for ds in db.datasets.values():
|
||||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
|
||||
|
||||
return db
|
||||
|
||||
|
||||
|
||||
def _lower_text(text_list):
|
||||
return [text.lower() for text in text_list]
|
||||
|
||||
def _split_list(text_list):
|
||||
return [text.split() for text in text_list]
|
||||
|
||||
def _convert_label(label, sent_len):
|
||||
np_label = np.zeros(sent_len, dtype=int)
|
||||
if label != []:
|
||||
np_label[np.array(label)] = 1
|
||||
return np_label.tolist()
|
||||
|
||||
def _pad_sent(text_wd, sent_max_len):
|
||||
pad_text_wd = []
|
||||
for sent_wd in text_wd:
|
||||
if len(sent_wd) < sent_max_len:
|
||||
pad_num = sent_max_len - len(sent_wd)
|
||||
sent_wd.extend([WORD_PAD] * pad_num)
|
||||
else:
|
||||
sent_wd = sent_wd[:sent_max_len]
|
||||
pad_text_wd.append(sent_wd)
|
||||
return pad_text_wd
|
||||
|
||||
def _token_mask(text_wd, sent_max_len):
|
||||
token_mask_list = []
|
||||
for sent_wd in text_wd:
|
||||
token_num = len(sent_wd)
|
||||
if token_num < sent_max_len:
|
||||
mask = [1] * token_num + [0] * (sent_max_len - token_num)
|
||||
else:
|
||||
mask = [1] * sent_max_len
|
||||
token_mask_list.append(mask)
|
||||
return token_mask_list
|
||||
|
||||
def _pad_label(label, doc_max_timesteps):
|
||||
text_len = len(label)
|
||||
if text_len < doc_max_timesteps:
|
||||
pad_label = label + [0] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
pad_label = label[:doc_max_timesteps]
|
||||
return pad_label
|
||||
|
||||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
|
||||
text_len = len(text_wd)
|
||||
if text_len < doc_max_timesteps:
|
||||
padding = [WORD_PAD] * sent_max_len
|
||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
pad_text = text_wd[:doc_max_timesteps]
|
||||
return pad_text
|
||||
|
||||
def _sent_mask(text_wd, doc_max_timesteps):
|
||||
text_len = len(text_wd)
|
||||
if text_len < doc_max_timesteps:
|
||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
sent_mask = [1] * doc_max_timesteps
|
||||
return sent_mask
|
||||
|
||||
|
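
To make the intended use of the new pipe concrete, here is a minimal usage sketch based on process_from_file above. It is illustrative only: the file paths are placeholders, and the vocabulary file is assumed to contain one word per line (optionally followed by a tab-separated count), as the reading loop above implies.

# Hypothetical usage of ExtCNNDMPipe; paths are placeholders.
from fastNLP.io.pipe.summarization import ExtCNNDMPipe

pipe = ExtCNNDMPipe(vocab_size=100000,
                    vocab_path="vocab",          # external word list, tab-separated
                    sent_max_len=100,            # tokens kept per sentence
                    doc_max_timesteps=50)        # sentences kept per document

# dict keys become dataset names in the returned DataBundle
db = pipe.process_from_file({"train": "train.label.jsonl",
                             "valid": "val.label.jsonl"})

train_ds = db.get_dataset("train")   # fields include words, seq_len, target
vocab = db.get_vocab("vocab")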
@@ -1,188 +0,0 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
from fastNLP.io.dataset_loader import JsonLoader
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
from tools.logger import *
|
||||
|
||||
WORD_PAD = "[PAD]"
|
||||
WORD_UNK = "[UNK]"
|
||||
DOMAIN_UNK = "X"
|
||||
TAG_UNK = "X"
|
||||
|
||||
|
||||
class SummarizationLoader(JsonLoader):
|
||||
"""
|
||||
Reads summarization datasets; the loaded DataSet contains the following fields::
|
||||
|
||||
text: list(str),document
|
||||
summary: list(str), summary
|
||||
text_wd: list(list(str)),tokenized document
|
||||
summary_wd: list(list(str)), tokenized summary
|
||||
labels: list(int),
|
||||
flatten_label: list(int), 0 or 1, flatten labels
|
||||
domain: str, optional
|
||||
tag: list(str), optional
|
||||
|
||||
Data sources: CNN_DailyMail, Newsroom, DUC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SummarizationLoader, self).__init__()
|
||||
|
||||
def _load(self, path):
|
||||
ds = super(SummarizationLoader, self)._load(path)
|
||||
|
||||
def _lower_text(text_list):
|
||||
return [text.lower() for text in text_list]
|
||||
|
||||
def _split_list(text_list):
|
||||
return [text.split() for text in text_list]
|
||||
|
||||
def _convert_label(label, sent_len):
|
||||
np_label = np.zeros(sent_len, dtype=int)
|
||||
if label != []:
|
||||
np_label[np.array(label)] = 1
|
||||
return np_label.tolist()
|
||||
|
||||
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
|
||||
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
|
||||
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd')
|
||||
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd')
|
||||
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
|
||||
|
||||
return ds
|
||||
|
||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
|
||||
"""
|
||||
:param paths: dict path for each dataset
|
||||
:param vocab_size: int max_size for vocab
|
||||
:param vocab_path: str vocab path
|
||||
:param sent_max_len: int max token number of the sentence
|
||||
:param doc_max_timesteps: int max sentence number of the document
|
||||
:param domain: bool build vocab for publication, use 'X' for unknown
|
||||
:param tag: bool build vocab for tag, use 'X' for unknown
|
||||
:param load_vocab_file: bool build vocab (False) or load vocab (True)
|
||||
:return: DataBundle
|
||||
datasets: dict keys correspond to the paths dict
|
||||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
|
||||
embeddings: optional
|
||||
"""
|
||||
|
||||
def _pad_sent(text_wd):
|
||||
pad_text_wd = []
|
||||
for sent_wd in text_wd:
|
||||
if len(sent_wd) < sent_max_len:
|
||||
pad_num = sent_max_len - len(sent_wd)
|
||||
sent_wd.extend([WORD_PAD] * pad_num)
|
||||
else:
|
||||
sent_wd = sent_wd[:sent_max_len]
|
||||
pad_text_wd.append(sent_wd)
|
||||
return pad_text_wd
|
||||
|
||||
def _token_mask(text_wd):
|
||||
token_mask_list = []
|
||||
for sent_wd in text_wd:
|
||||
token_num = len(sent_wd)
|
||||
if token_num < sent_max_len:
|
||||
mask = [1] * token_num + [0] * (sent_max_len - token_num)
|
||||
else:
|
||||
mask = [1] * sent_max_len
|
||||
token_mask_list.append(mask)
|
||||
return token_mask_list
|
||||
|
||||
def _pad_label(label):
|
||||
text_len = len(label)
|
||||
if text_len < doc_max_timesteps:
|
||||
pad_label = label + [0] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
pad_label = label[:doc_max_timesteps]
|
||||
return pad_label
|
||||
|
||||
def _pad_doc(text_wd):
|
||||
text_len = len(text_wd)
|
||||
if text_len < doc_max_timesteps:
|
||||
padding = [WORD_PAD] * sent_max_len
|
||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
pad_text = text_wd[:doc_max_timesteps]
|
||||
return pad_text
|
||||
|
||||
def _sent_mask(text_wd):
|
||||
text_len = len(text_wd)
|
||||
if text_len < doc_max_timesteps:
|
||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
|
||||
else:
|
||||
sent_mask = [1] * doc_max_timesteps
|
||||
return sent_mask
|
||||
|
||||
|
||||
datasets = {}
|
||||
train_ds = None
|
||||
for key, value in paths.items():
|
||||
ds = self.load(value)
|
||||
# pad sent
|
||||
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
|
||||
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask")
|
||||
# pad document
|
||||
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
|
||||
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
|
||||
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label")
|
||||
|
||||
# rename field
|
||||
ds.rename_field("pad_text", Const.INPUT)
|
||||
ds.rename_field("seq_len", Const.INPUT_LEN)
|
||||
ds.rename_field("pad_label", Const.TARGET)
|
||||
|
||||
# set input and target
|
||||
ds.set_input(Const.INPUT, Const.INPUT_LEN)
|
||||
ds.set_target(Const.TARGET, Const.INPUT_LEN)
|
||||
|
||||
datasets[key] = ds
|
||||
if "train" in key:
|
||||
train_ds = datasets[key]
|
||||
|
||||
vocab_dict = {}
|
||||
if load_vocab_file == False:
|
||||
logger.info("[INFO] Build new vocab from training dataset!")
|
||||
if train_ds == None:
|
||||
raise ValueError("Lack train file to build vocabulary!")
|
||||
|
||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
|
||||
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"])
|
||||
vocab_dict["vocab"] = vocabs
|
||||
else:
|
||||
logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
|
||||
word_list = []
|
||||
with open(vocab_path, 'r', encoding='utf8') as vocab_f:
|
||||
cnt = 2 # pad and unk
|
||||
for line in vocab_f:
|
||||
pieces = line.split("\t")
|
||||
word_list.append(pieces[0])
|
||||
cnt += 1
|
||||
if cnt > vocab_size:
|
||||
break
|
||||
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
|
||||
vocabs.add_word_lst(word_list)
|
||||
vocabs.build_vocab()
|
||||
vocab_dict["vocab"] = vocabs
|
||||
|
||||
if domain == True:
|
||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
|
||||
domaindict.from_dataset(train_ds, field_name="publication")
|
||||
vocab_dict["domain"] = domaindict
|
||||
if tag == True:
|
||||
tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
|
||||
tagdict.from_dataset(train_ds, field_name="tag")
|
||||
vocab_dict["tag"] = tagdict
|
||||
|
||||
for ds in datasets.values():
|
||||
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
|
||||
|
||||
return DataBundle(vocabs=vocab_dict, datasets=datasets)
|
||||
|
||||
|
||||
|
@@ -94,6 +94,8 @@ class Encoder(nn.Module):
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
+        # print(enc_embed_input.size())
+        # print(enc_pos_embed_input.size())
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
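
The hunk above combines word embeddings with a frozen sinusoidal position embedding before the convolutional sentence encoder. As a standalone illustration of the shape bookkeeping, not the project's exact module, a sketch with made-up dimensions:

# Illustrative sketch only: add position embeddings to word embeddings, then
# run a Conv2d over the summed input, as the Encoder hunk above does.
import torch
import torch.nn as nn
import torch.nn.functional as F

B_N, L, D = 8, 100, 300                        # batch*N sentences, tokens, emb dim
enc_embed_input = torch.randn(B_N, L, D)       # word embeddings
position_embedding = nn.Embedding(L + 1, D)    # index 0 reserved for padding
input_pos = torch.arange(1, L + 1).expand(B_N, L)

enc_pos_embed_input = position_embedding(input_pos)                     # [B*N, L, D]
enc_conv_input = (enc_embed_input + enc_pos_embed_input).unsqueeze(1)   # [B*N, 1, L, D]
conv = nn.Conv2d(1, 50, kernel_size=(3, D))
enc_conv_output = F.relu(conv(enc_conv_input)).squeeze(3)               # [B*N, 50, L-2]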
@@ -17,11 +17,12 @@ class SummarizationModel(nn.Module):
        """
        :param hps: hyperparameters for the model
        :param vocab: vocab object
        :param embed: word embedding
        """
        super(SummarizationModel, self).__init__()

        self._hps = hps
        self.Train = (hps.mode == 'train')

        # sentence encoder
        self.encoder = Encoder(hps, embed)
@@ -45,18 +46,19 @@ class SummarizationModel(nn.Module):
        self.wh = nn.Linear(self.d_v, 2)

-    def forward(self, input, input_len, Train):
+    def forward(self, words, seq_len):
        """
        :param input: [batch_size, N, seq_len], word idx long tensor
        :param input_len: [batch_size, N], 1 for sentence and 0 for padding
        :param Train: True for train and False for eval and test
        :param return_atten: True or False to return multi-head attention output self.output_slf_attn
        :return:
            p_sent: [batch_size, N, 2]
            output_slf_attn: (optional) [n_head, batch_size, N, N]
        """
        input = words
        input_len = seq_len

        # -- Sentence Encoder
        self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]

@@ -67,7 +69,7 @@ class SummarizationModel(nn.Module):
        self.inputs[0] = self.sent_embedding.permute(1, 0, 2)  # [N, batch, Co * kernel_sizes]
        self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2)

-        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train)  # [batch, N, hidden_size]
+        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train=self.train)  # [batch, N, hidden_size]

        # -- Prepare masks
        batch_size, N = input_len.size()
@@ -21,7 +21,7 @@ import torch
import torch.nn.functional as F

from fastNLP.core.losses import LossBase
-from tools.logger import *
+from fastNLP.core._logger import logger

class MyCrossEntropyLoss(LossBase):
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
@@ -20,14 +20,60 @@ from __future__ import division

import torch
import torch.nn.functional as F

from rouge import Rouge

from fastNLP.core.const import Const
from fastNLP.core.metrics import MetricBase

-from tools.logger import *
+# from tools.logger import *
+from fastNLP.core._logger import logger
from tools.utils import pyrouge_score_all, pyrouge_score_all_multi


class LossMetric(MetricBase):
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
        super().__init__()

        self._init_param_map(pred=pred, target=target, mask=mask)
        self.padding_idx = padding_idx
        self.reduce = reduce
        self.loss = 0.0
        self.iteration = 0

    def evaluate(self, pred, target, mask):
        """
        :param pred: [batch, N, 2]
        :param target: [batch, N]
        :param mask: [batch, N]
        :return:
        """
        batch, N, _ = pred.size()
        pred = pred.view(-1, 2)
        target = target.view(-1)
        loss = F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx, reduction=self.reduce)
        loss = loss.view(batch, -1)
        loss = loss.masked_fill(mask.eq(0), 0)
        loss = loss.sum(1).mean()
        self.loss += loss
        self.iteration += 1

    def get_metric(self, reset=True):
        epoch_avg_loss = self.loss / self.iteration
        if reset:
            self.loss = 0.0
            self.iteration = 0
        metric = {"loss": -epoch_avg_loss}
        logger.info(metric)
        return metric


class LabelFMetric(MetricBase):
    def __init__(self, pred=None, target=None):
        super().__init__()
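
For clarity, the masked cross-entropy that LossMetric.evaluate accumulates above can be reproduced in isolation. The tensors below are dummy data used only to show the shapes; this is an illustration, not code from the commit.

# Standalone illustration of the masked, per-document loss in LossMetric.evaluate.
import torch
import torch.nn.functional as F

batch, N = 2, 4
pred = torch.randn(batch, N, 2)               # sentence-level logits
target = torch.randint(0, 2, (batch, N))      # 0/1 extraction labels
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])           # 1 = real sentence, 0 = padding

loss = F.cross_entropy(pred.view(-1, 2), target.view(-1), reduction='none')
loss = loss.view(batch, -1)
loss = loss.masked_fill(mask.eq(0), 0)        # zero out padded positions
loss = loss.sum(1).mean()                     # sum per document, average over batch
print(float(loss))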
@@ -51,7 +51,7 @@ class TransformerModel(nn.Module):
            ffn_inner_hidden_size: FFN hidden size
            atten_dropout_prob: dropout size
            doc_max_timesteps: max sentence number of the document
        :param vocab:
        :param embed: word embedding
        """
        super(TransformerModel, self).__init__()
@@ -28,7 +28,7 @@ from fastNLP.core.const import Const
from fastNLP.io.model_io import ModelSaver
from fastNLP.core.callback import Callback, EarlyStopError

-from tools.logger import *
+from fastNLP.core._logger import logger

class TrainCallback(Callback):
    def __init__(self, hps, patience=3, quit_all=True):
@@ -36,6 +36,9 @@ class TrainCallback(Callback):
        self._hps = hps
        self.patience = patience
        self.wait = 0
        self.train_loss = 0.0
        self.prev_train_avg_loss = 1000.0
        self.train_dir = os.path.join(self._hps.save_root, "train")

        if type(quit_all) != bool:
            raise ValueError("In KeyboardInterrupt, the quit_all argument must be a bool.")
@@ -43,20 +46,7 @@ class TrainCallback(Callback):

    def on_epoch_begin(self):
        self.epoch_start_time = time.time()

-    # def on_loss_begin(self, batch_y, predict_y):
-    #     """
-    #
-    #     :param batch_y: dict
-    #         input_len: [batch, N]
-    #     :param predict_y: dict
-    #         p_sent: [batch, N, 2]
-    #     :return:
-    #     """
-    #     input_len = batch_y[Const.INPUT_LEN]
-    #     batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100)
-    #     # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1)
-    #     # logger.debug(predict_y["p_sent"][0:5,:,:])
        self.model.Train = True

    def on_backward_begin(self, loss):
        """
@@ -72,19 +62,34 @@ class TrainCallback(Callback):
            logger.info(name)
            logger.info(param.grad.data.sum())
            raise Exception("train Loss is not finite. Stopping.")
        self.train_loss += loss.data

    def on_backward_end(self):
        if self._hps.grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)
        torch.cuda.empty_cache()

    def on_epoch_end(self):
-        logger.info('   | end of epoch {:3d} | time: {:5.2f}s | '
-                    .format(self.epoch, (time.time() - self.epoch_start_time)))
+        epoch_avg_loss = self.train_loss / self.n_steps
+        logger.info('   | end of epoch {:3d} | time: {:5.2f}s | train loss: {:5.6f}'
+                    .format(self.epoch, (time.time() - self.epoch_start_time), epoch_avg_loss))
+        if self.prev_train_avg_loss < epoch_avg_loss:
+            save_file = os.path.join(self.train_dir, "earlystop.pkl")
+            self.save_model(save_file)
+        else:
+            self.prev_train_avg_loss = epoch_avg_loss
+        self.train_loss = 0.0
+
+        # save epoch
+        save_file = os.path.join(self.train_dir, "epoch_%d.pkl" % self.epoch)
+        self.save_model(save_file)

    def on_valid_begin(self):
        self.valid_start_time = time.time()
        self.model.Train = False

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        logger.info('   | end of valid {:3d} | time: {:5.2f}s | '
@@ -95,9 +100,7 @@ class TrainCallback(Callback):
        if self.wait == self.patience:
            train_dir = os.path.join(self._hps.save_root, "train")
            save_file = os.path.join(train_dir, "earlystop.pkl")
-            saver = ModelSaver(save_file)
-            saver.save_pytorch(self.model)
-            logger.info('[INFO] Saving early stop model to %s', save_file)
+            self.save_model(save_file)
            raise EarlyStopError("Early stopping raised.")
        else:
            self.wait += 1
@@ -111,14 +114,12 @@ class TrainCallback(Callback):
            param_group['lr'] = new_lr
        logger.info("[INFO] The learning rate now is %f", new_lr)

    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
-            train_dir = os.path.join(self._hps.save_root, "train")
-            save_file = os.path.join(train_dir, "earlystop.pkl")
-            saver = ModelSaver(save_file)
-            saver.save_pytorch(self.model)
-            logger.info('[INFO] Saving early stop model to %s', save_file)
+            save_file = os.path.join(self.train_dir, "earlystop.pkl")
+            self.save_model(save_file)

            if self.quit_all is True:
                sys.exit(0)  # exit the program directly
@@ -127,6 +128,11 @@ class TrainCallback(Callback):
            else:
                raise exception  # re-raise unrecognized errors

    def save_model(self, save_file):
        saver = ModelSaver(save_file)
        saver.save_pytorch(self.model)
        logger.info('[INFO] Saving model to %s', save_file)
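
The refactor above funnels every checkpoint through the new save_model helper. A minimal standalone sketch of that pattern, shown only for illustration (the toy model and file name are placeholders, not part of the commit):

# Illustrative sketch of the save_model pattern introduced above.
import torch.nn as nn
from fastNLP.io.model_io import ModelSaver
from fastNLP.core._logger import logger

def save_model(model, save_file):
    saver = ModelSaver(save_file)
    saver.save_pytorch(model)
    logger.info('[INFO] Saving model to %s', save_file)

save_model(nn.Linear(4, 2), "earlystop.pkl")   # toy model, placeholder path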
@@ -1,562 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import *
|
||||
import torch.nn.init as init
|
||||
|
||||
import data
|
||||
from tools.logger import *
|
||||
from transformer.Models import get_sinusoid_encoding_table
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, hps, vocab):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
self._hps = hps
|
||||
self._vocab = vocab
|
||||
self.sent_max_len = hps.sent_max_len
|
||||
|
||||
vocab_size = len(vocab)
|
||||
logger.info("[INFO] Vocabulary size is %d", vocab_size)
|
||||
embed_size = hps.word_emb_dim
|
||||
sent_max_len = hps.sent_max_len
|
||||
|
||||
input_channels = 1
|
||||
out_channels = hps.output_channel
|
||||
min_kernel_size = hps.min_kernel_size
|
||||
max_kernel_size = hps.max_kernel_size
|
||||
width = embed_size
|
||||
|
||||
# word embedding
|
||||
self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]'))
|
||||
if hps.word_embedding:
|
||||
word2vec = data.Word_Embedding(hps.embedding_path, vocab)
|
||||
word_vecs = word2vec.load_my_vecs(embed_size)
|
||||
# pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size)
|
||||
pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size)
|
||||
pretrained_weight = np.array(pretrained_weight)
|
||||
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
|
||||
self.embed.weight.requires_grad = hps.embed_train
|
||||
|
||||
# position embedding
|
||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
|
||||
|
||||
# cnn
|
||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
|
||||
logger.info("[INFO] Initing W for CNN.......")
|
||||
for conv in self.convs:
|
||||
init_weight_value = 6.0
|
||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
|
||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
|
||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
|
||||
|
||||
def calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.ndimension()
|
||||
if dimensions < 2:
|
||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor.size(1)
|
||||
fan_out = tensor.size(0)
|
||||
else:
|
||||
num_input_fmaps = tensor.size(1)
|
||||
num_output_fmaps = tensor.size(0)
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
def forward(self, input):
|
||||
# input: a batch of Example object [batch_size, N, seq_len]
|
||||
vocab = self._vocab
|
||||
|
||||
batch_size, N, _ = input.size()
|
||||
input = input.view(-1, input.size(2)) # [batch_size*N, L]
|
||||
input_sent_len = ((input!=vocab.word2id('[PAD]')).sum(dim=1)).int() # [batch_size*N, 1]
|
||||
enc_embed_input = self.embed(input) # [batch_size*N, L, D]
|
||||
|
||||
input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
|
||||
if self._hps.cuda:
|
||||
input_pos = input_pos.cuda()
|
||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
|
||||
enc_conv_input = enc_embed_input + enc_pos_embed_input
|
||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
|
||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
|
||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
|
||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
|
||||
sent_embedding = sent_embedding.view(batch_size, N, -1)
|
||||
return sent_embedding
|
||||
|
||||
class DomainEncoder(Encoder):
|
||||
def __init__(self, hps, vocab, domaindict):
|
||||
super(DomainEncoder, self).__init__(hps, vocab)
|
||||
|
||||
# domain embedding
|
||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
|
||||
self.domain_embedding.weight.requires_grad = True
|
||||
|
||||
def forward(self, input, domain):
|
||||
"""
|
||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number
|
||||
:param domain: [batch_size]
|
||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes]
|
||||
"""
|
||||
|
||||
batch_size, N, _ = input.size()
|
||||
|
||||
sent_embedding = super().forward(input)
|
||||
enc_domain_input = self.domain_embedding(domain) # [batch, D]
|
||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
|
||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
|
||||
return sent_embedding
|
||||
|
||||
class MultiDomainEncoder(Encoder):
|
||||
def __init__(self, hps, vocab, domaindict):
|
||||
super(MultiDomainEncoder, self).__init__(hps, vocab)
|
||||
|
||||
self.domain_size = domaindict.size()
|
||||
|
||||
# domain embedding
|
||||
self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
|
||||
self.domain_embedding.weight.requires_grad = True
|
||||
|
||||
def forward(self, input, domain):
|
||||
"""
|
||||
:param input: [batch_size, N, seq_len], N sentence number, seq_len token number
|
||||
:param domain: [batch_size, domain_size]
|
||||
:return: sent_embedding: [batch_size, N, Co * kernel_sizes]
|
||||
"""
|
||||
|
||||
batch_size, N, _ = input.size()
|
||||
|
||||
# logger.info(domain[:5, :])
|
||||
|
||||
sent_embedding = super().forward(input)
|
||||
domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
|
||||
domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size]
|
||||
|
||||
enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D]
|
||||
enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D]
|
||||
|
||||
# logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D]
|
||||
|
||||
enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D]
|
||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
|
||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
|
||||
return sent_embedding
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, hps):
|
||||
super(BertEncoder, self).__init__()
|
||||
|
||||
from pytorch_pretrained_bert.modeling import BertModel
|
||||
|
||||
self._hps = hps
|
||||
self.sent_max_len = hps.sent_max_len
|
||||
self._cuda = hps.cuda
|
||||
|
||||
embed_size = hps.word_emb_dim
|
||||
sent_max_len = hps.sent_max_len
|
||||
|
||||
input_channels = 1
|
||||
out_channels = hps.output_channel
|
||||
min_kernel_size = hps.min_kernel_size
|
||||
max_kernel_size = hps.max_kernel_size
|
||||
width = embed_size
|
||||
|
||||
# word embedding
|
||||
self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
|
||||
self._bert.eval()
|
||||
for p in self._bert.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.word_embedding_proj = nn.Linear(4096, embed_size)
|
||||
|
||||
# position embedding
|
||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
|
||||
|
||||
# cnn
|
||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
|
||||
logger.info("[INFO] Initing W for CNN.......")
|
||||
for conv in self.convs:
|
||||
init_weight_value = 6.0
|
||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
|
||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
|
||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
|
||||
|
||||
def calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.ndimension()
|
||||
if dimensions < 2:
|
||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor.size(1)
|
||||
fan_out = tensor.size(0)
|
||||
else:
|
||||
num_input_fmaps = tensor.size(1)
|
||||
num_output_fmaps = tensor.size(0)
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
def pad_encoder_input(self, input_list):
|
||||
"""
|
||||
:param input_list: N [seq_len, hidden_state]
|
||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state]
|
||||
"""
|
||||
max_len = self.sent_max_len
|
||||
enc_sent_input_pad = []
|
||||
_, hidden_size = input_list[0].size()
|
||||
for i in range(len(input_list)):
|
||||
article_words = input_list[i] # [seq_len, hidden_size]
|
||||
seq_len = article_words.size(0)
|
||||
if seq_len > max_len:
|
||||
pad_words = article_words[:max_len, :]
|
||||
else:
|
||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
|
||||
pad_words = torch.cat([article_words, pad_tensor], dim=0)
|
||||
enc_sent_input_pad.append(pad_words)
|
||||
return enc_sent_input_pad
|
||||
|
||||
def forward(self, inputs, input_masks, enc_sent_len):
|
||||
"""
|
||||
|
||||
:param inputs: a batch of Example object [batch_size, doc_len=512]
|
||||
:param input_masks: 0 or 1, [batch, doc_len=512]
|
||||
:param enc_sent_len: sentence original length [batch, N]
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
# Use Bert to get word embedding
|
||||
batch_size, N = enc_sent_len.size()
|
||||
input_pad_list = []
|
||||
for i in range(batch_size):
|
||||
tokens_id = inputs[i]
|
||||
input_mask = input_masks[i]
|
||||
sent_len = enc_sent_len[i]
|
||||
input_ids = tokens_id.unsqueeze(0)
|
||||
input_mask = input_mask.unsqueeze(0)
|
||||
|
||||
out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
|
||||
out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096]
|
||||
|
||||
_, hidden_size = out.size()
|
||||
|
||||
# restore the sentence
|
||||
last_end = 1
|
||||
enc_sent_input = []
|
||||
for length in sent_len:
|
||||
if length != 0 and last_end < 511:
|
||||
enc_sent_input.append(out[last_end: min(511, last_end + length), :])
|
||||
last_end += length
|
||||
else:
|
||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
|
||||
enc_sent_input.append(pad_tensor)
|
||||
|
||||
|
||||
# pad the sentence
|
||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096]
|
||||
input_pad_list.append(torch.stack(enc_sent_input_pad))
|
||||
|
||||
input_pad = torch.stack(input_pad_list)
|
||||
|
||||
input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1)
|
||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N]
|
||||
enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D]
|
||||
|
||||
sent_pos_list = []
|
||||
for sentlen in enc_sent_len:
|
||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
|
||||
for k in range(self.sent_max_len - sentlen):
|
||||
sent_pos.append(0)
|
||||
sent_pos_list.append(sent_pos)
|
||||
input_pos = torch.Tensor(sent_pos_list).long()
|
||||
|
||||
if self._hps.cuda:
|
||||
input_pos = input_pos.cuda()
|
||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
|
||||
enc_conv_input = enc_embed_input + enc_pos_embed_input
|
||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
|
||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
|
||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
|
||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
|
||||
sent_embedding = sent_embedding.view(batch_size, N, -1)
|
||||
return sent_embedding
|
||||
|
||||
|
||||
class BertTagEncoder(BertEncoder):
|
||||
def __init__(self, hps, domaindict):
|
||||
super(BertTagEncoder, self).__init__(hps)
|
||||
|
||||
# domain embedding
|
||||
self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
|
||||
self.domain_embedding.weight.requires_grad = True
|
||||
|
||||
def forward(self, inputs, input_masks, enc_sent_len, domain):
|
||||
sent_embedding = super().forward(inputs, input_masks, enc_sent_len)
|
||||
|
||||
batch_size, N = enc_sent_len.size()
|
||||
|
||||
enc_domain_input = self.domain_embedding(domain) # [batch, D]
|
||||
enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
|
||||
sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
|
||||
|
||||
return sent_embedding
|
||||
|
||||
class ELMoEndoer(nn.Module):
|
||||
def __init__(self, hps):
|
||||
super(ELMoEndoer, self).__init__()
|
||||
|
||||
self._hps = hps
|
||||
self.sent_max_len = hps.sent_max_len
|
||||
|
||||
from allennlp.modules.elmo import Elmo
|
||||
|
||||
elmo_dim = 1024
|
||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
|
||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
|
||||
|
||||
# elmo_dim = 512
|
||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
|
||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
|
||||
|
||||
embed_size = hps.word_emb_dim
|
||||
sent_max_len = hps.sent_max_len
|
||||
|
||||
input_channels = 1
|
||||
out_channels = hps.output_channel
|
||||
min_kernel_size = hps.min_kernel_size
|
||||
max_kernel_size = hps.max_kernel_size
|
||||
width = embed_size
|
||||
|
||||
# elmo embedding
|
||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
|
||||
self.embed_proj = nn.Linear(elmo_dim, embed_size)
|
||||
|
||||
# position embedding
|
||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
|
||||
|
||||
# cnn
|
||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
|
||||
logger.info("[INFO] Initing W for CNN.......")
|
||||
for conv in self.convs:
|
||||
init_weight_value = 6.0
|
||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
|
||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
|
||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
|
||||
|
||||
def calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.ndimension()
|
||||
if dimensions < 2:
|
||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor.size(1)
|
||||
fan_out = tensor.size(0)
|
||||
else:
|
||||
num_input_fmaps = tensor.size(1)
|
||||
num_output_fmaps = tensor.size(0)
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
def forward(self, input):
|
||||
# input: a batch of Example object [batch_size, N, seq_len, character_len]
|
||||
|
||||
batch_size, N, seq_len, _ = input.size()
|
||||
input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len]
|
||||
input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1]
|
||||
logger.debug(input_sent_len.view(batch_size, -1))
|
||||
enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D]
|
||||
enc_embed_input = self.embed_proj(enc_embed_input)
|
||||
|
||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
|
||||
|
||||
sent_pos_list = []
|
||||
for sentlen in input_sent_len:
|
||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
|
||||
for k in range(self.sent_max_len - sentlen):
|
||||
sent_pos.append(0)
|
||||
sent_pos_list.append(sent_pos)
|
||||
input_pos = torch.Tensor(sent_pos_list).long()
|
||||
|
||||
if self._hps.cuda:
|
||||
input_pos = input_pos.cuda()
|
||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
|
||||
enc_conv_input = enc_embed_input + enc_pos_embed_input
|
||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
|
||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
|
||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
|
||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
|
||||
sent_embedding = sent_embedding.view(batch_size, N, -1)
|
||||
return sent_embedding
|
||||
|
||||
class ELMoEndoer2(nn.Module):
|
||||
def __init__(self, hps):
|
||||
super(ELMoEndoer2, self).__init__()
|
||||
|
||||
self._hps = hps
|
||||
self._cuda = hps.cuda
|
||||
self.sent_max_len = hps.sent_max_len
|
||||
|
||||
from allennlp.modules.elmo import Elmo
|
||||
|
||||
elmo_dim = 1024
|
||||
options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
|
||||
weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
|
||||
|
||||
# elmo_dim = 512
|
||||
# options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
|
||||
# weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
|
||||
|
||||
embed_size = hps.word_emb_dim
|
||||
sent_max_len = hps.sent_max_len
|
||||
|
||||
input_channels = 1
|
||||
out_channels = hps.output_channel
|
||||
min_kernel_size = hps.min_kernel_size
|
||||
max_kernel_size = hps.max_kernel_size
|
||||
width = embed_size
|
||||
|
||||
# elmo embedding
|
||||
self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
|
||||
self.embed_proj = nn.Linear(elmo_dim, embed_size)
|
||||
|
||||
# position embedding
|
||||
self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
|
||||
|
||||
# cnn
|
||||
self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
|
||||
logger.info("[INFO] Initing W for CNN.......")
|
||||
for conv in self.convs:
|
||||
init_weight_value = 6.0
|
||||
init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
|
||||
fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
|
||||
std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
|
||||
|
||||
def calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.ndimension()
|
||||
if dimensions < 2:
|
||||
logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
|
||||
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor.size(1)
|
||||
fan_out = tensor.size(0)
|
||||
else:
|
||||
num_input_fmaps = tensor.size(1)
|
||||
num_output_fmaps = tensor.size(0)
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
def pad_encoder_input(self, input_list):
|
||||
"""
|
||||
:param input_list: N [seq_len, hidden_state]
|
||||
:return: enc_sent_input_pad: list, N [max_len, hidden_state]
|
||||
"""
|
||||
max_len = self.sent_max_len
|
||||
enc_sent_input_pad = []
|
||||
_, hidden_size = input_list[0].size()
|
||||
for i in range(len(input_list)):
|
||||
article_words = input_list[i] # [seq_len, hidden_size]
|
||||
seq_len = article_words.size(0)
|
||||
if seq_len > max_len:
|
||||
pad_words = article_words[:max_len, :]
|
||||
else:
|
||||
pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
|
||||
pad_words = torch.cat([article_words, pad_tensor], dim=0)
|
||||
enc_sent_input_pad.append(pad_words)
|
||||
return enc_sent_input_pad
|
||||
|
||||
def forward(self, inputs, input_masks, enc_sent_len):
|
||||
"""
|
||||
|
||||
:param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50]
|
||||
:param input_masks: 0 or 1, [batch, doc_len=512]
|
||||
:param enc_sent_len: sentence original length [batch, N]
|
||||
:return:
|
||||
sent_embedding: [batch, N, D]
|
||||
"""
|
||||
|
||||
# Use Bert to get word embedding
|
||||
batch_size, N = enc_sent_len.size()
|
||||
input_pad_list = []
|
||||
|
||||
elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D]
|
||||
elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
|
||||
# print("END elmo")
|
||||
|
||||
for i in range(batch_size):
|
||||
sent_len = enc_sent_len[i] # [1, N]
|
||||
out = elmo_output[i]
|
||||
|
||||
_, hidden_size = out.size()
|
||||
|
||||
# restore the sentence
|
||||
last_end = 0
|
||||
enc_sent_input = []
|
||||
for length in sent_len:
|
||||
if length != 0 and last_end < 512:
|
||||
enc_sent_input.append(out[last_end : min(512, last_end + length), :])
|
||||
last_end += length
|
||||
else:
|
||||
pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
|
||||
enc_sent_input.append(pad_tensor)
|
||||
|
||||
# pad the sentence
|
||||
enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096]
|
||||
input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state]
|
||||
|
||||
input_pad = torch.stack(input_pad_list)
|
||||
|
||||
input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
|
||||
enc_sent_len = enc_sent_len.view(-1) # [batch_size*N]
|
||||
enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D]
|
||||
|
||||
# input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
|
||||
|
||||
sent_pos_list = []
|
||||
for sentlen in enc_sent_len:
|
||||
sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
|
||||
for k in range(self.sent_max_len - sentlen):
|
||||
sent_pos.append(0)
|
||||
sent_pos_list.append(sent_pos)
|
||||
input_pos = torch.Tensor(sent_pos_list).long()
|
||||
|
||||
if self._hps.cuda:
|
||||
input_pos = input_pos.cuda()
|
||||
enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
|
||||
enc_conv_input = enc_embed_input + enc_pos_embed_input
|
||||
enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
|
||||
enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
|
||||
enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
|
||||
sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
|
||||
sent_embedding = sent_embedding.view(batch_size, N, -1)
|
||||
return sent_embedding
|
@@ -21,6 +21,7 @@
import os
import sys
import json
import shutil
import argparse
import datetime

@@ -32,20 +33,25 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/')

from fastNLP.core._logger import logger
# from fastNLP.core._logger import _init_logger
from fastNLP.core.const import Const
from fastNLP.core.trainer import Trainer, Tester
from fastNLP.io.pipe.summarization import ExtCNNDMPipe
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader

from tools.logger import *
from data.dataloader import SummarizationLoader
# from tools.logger import *
# from model.TransformerModel import TransformerModel
from model.TForiginal import TransformerModel
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric
from model.LSTMModel import SummarizationModel
from model.Metric import LossMetric, LabelFMetric, FastRougeMetric, PyRougeMetric
from model.Loss import MyCrossEntropyLoss
from tools.Callback import TrainCallback


def setup_training(model, train_loader, valid_loader, hps):
    """Does setup before starting training (run_training)"""

@@ -60,32 +66,23 @@ def setup_training(model, train_loader, valid_loader, hps):
    else:
        logger.info("[INFO] Create new model for training...")

-    try:
-        run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
-    except KeyboardInterrupt:
-        logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
-        save_file = os.path.join(train_dir, "earlystop.pkl")
-        saver = ModelSaver(save_file)
-        saver.save_pytorch(model)
-        logger.info('[INFO] Saving early stop model to %s', save_file)
+    run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted

def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
-    if not os.path.exists(train_dir): os.makedirs(train_dir)
+    if os.path.exists(train_dir): shutil.rmtree(train_dir)
+    os.makedirs(train_dir)
    eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

-    lr = hps.lr
-    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr)
    criterion = MyCrossEntropyLoss(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none')
    # criterion = torch.nn.CrossEntropyLoss(reduce="none")

    trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion,
-                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
-                      metric_key="f", validate_every=-1, save_path=eval_dir,
+                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LossMetric(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none'), LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
+                      metric_key="loss", validate_every=-1, save_path=eval_dir,
                      callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False)

    train_info = trainer.train(load_best_model=True)
@@ -98,8 +95,8 @@ def run_training(model, train_loader, valid_loader, hps):
    saver.save_pytorch(model)
    logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)

-def run_test(model, loader, hps, limited=False):
-    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
+def run_test(model, loader, hps):
    test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for eval data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
@@ -113,8 +110,8 @@ def run_test(model, loader, hps, limited=False):
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
-        logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
-        raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
+        logger.error("None of such model! Must be one of evalbestmodel/earlystop")
+        raise ValueError("None of such model! Must be one of evalbestmodel/earlystop")
    logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
@@ -174,13 +171,11 @@ def main():
    # Training
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')
    parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps')
    parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping')
    parser.add_argument('--max_grad_norm', type=float, default=10, help='for gradient clipping max gradient normalization')

    # test
    parser.add_argument('-m', type=int, default=3, help='decode summary length')
    parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length')
    parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]')
    parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge')

@@ -195,36 +190,42 @@ def main():
    VOCAL_FILE = args.vocab_path
    LOG_PATH = args.log_root

-    # train_log setting
+    # # train_log setting
    if not os.path.exists(LOG_PATH):
        if args.mode == "train":
            os.makedirs(LOG_PATH)
        else:
            logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH)
            raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH))
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime)
-    file_handler = logging.FileHandler(log_path)
-    file_handler.setFormatter(formatter)
-    logger.addHandler(file_handler)
+    # logger = _init_logger(path=log_path)
+    # file_handler = logging.FileHandler(log_path)
+    # file_handler.setFormatter(formatter)
+    # logger.addHandler(file_handler)

    logger.info("Pytorch %s", torch.__version__)
    sum_loader = SummarizationLoader()

    # dataset
    hps = args
    dbPipe = ExtCNNDMPipe(vocab_size=hps.vocab_size,
                          vocab_path=VOCAL_FILE,
                          sent_max_len=hps.sent_max_len,
                          doc_max_timesteps=hps.doc_max_timesteps)
    if hps.mode == 'test':
-        paths = {"test": DATA_FILE}
        hps.recurrent_dropout_prob = 0.0
        hps.atten_dropout_prob = 0.0
        hps.ffn_dropout_prob = 0.0
        logger.info(hps)
+        paths = {"test": DATA_FILE}
+        db = dbPipe.process_from_file(paths)
    else:
        paths = {"train": DATA_FILE, "valid": VALID_FILE}
+        db = dbPipe.process_from_file(paths)

    dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE))

    # embedding
    if args.embedding == "glove":
-        vocab = dataInfo.vocabs["vocab"]
+        vocab = db.get_vocab("vocab")
        embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim)
        if hps.word_embedding:
            embed_loader = EmbedLoader()
@@ -235,26 +236,31 @@ def main():
        logger.error("[ERROR] embedding To Be Continued!")
        sys.exit(1)

    # model
    if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab":
        model_param = json.load(open("config/transformer.config", "rb"))
        hps.__dict__.update(model_param)
        model = TransformerModel(hps, embed)
    elif args.sentence_encoder == "deeplstm" and args.sentence_decoder == "SeqLab":
        model_param = json.load(open("config/deeplstm.config", "rb"))
        hps.__dict__.update(model_param)
        model = SummarizationModel(hps, embed)
    else:
        logger.error("[ERROR] Model To Be Continued!")
        sys.exit(1)

-    logger.info(hps)

    if hps.cuda:
        model = model.cuda()
        logger.info("[INFO] Use cuda")

+    logger.info(hps)

    if hps.mode == 'train':
-        dataInfo.datasets["valid"].set_target("text", "summary")
-        setup_training(model, dataInfo.datasets["train"], dataInfo.datasets["valid"], hps)
+        db.get_dataset("valid").set_target("text", "summary")
+        setup_training(model, db.get_dataset("train"), db.get_dataset("valid"), hps)
    elif hps.mode == 'test':
        logger.info("[INFO] Decoding...")
-        dataInfo.datasets["test"].set_target("text", "summary")
-        run_test(model, dataInfo.datasets["test"], hps, limited=hps.limited)
+        db.get_dataset("test").set_target("text", "summary")
+        run_test(model, db.get_dataset("test"), hps, limited=hps.limited)
    else:
        logger.error("The 'mode' flag must be one of train/eval/test")
        raise ValueError("The 'mode' flag must be one of train/eval/test")
@ -18,7 +18,7 @@ Models implemented in FastNLP include:

The summarization datasets provided here include:

- CNN/DailyMail
- CNN/DailyMail ([Get To The Point: Summarization with Pointer-Generator Networks](http://arxiv.org/abs/1704.04368))
- Newsroom
- The New York Times Annotated Corpus
- NYT
@ -110,11 +110,11 @@ $ python -m pyrouge.test

LSTM + Sequence Labeling

python train.py --cuda --gpu <gpuid> --sentence_encoder deeplstm --sentence_decoder seqlab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10
python train.py --cuda --gpu <gpuid> --sentence_encoder deeplstm --sentence_decoder SeqLab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10

Transformer + Sequence Labeling

python train.py --cuda --gpu <gpuid> --sentence_encoder transformer --sentence_decoder seqlab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10
python train.py --cuda --gpu <gpuid> --sentence_encoder transformer --sentence_decoder SeqLab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10
10
test/data_for_tests/cnndm.jsonl
Normal file
File diff suppressed because one or more lines are too long
100
test/data_for_tests/cnndm.vocab
Normal file
@ -0,0 +1,100 @@
. 12172211
the 11896296
, 9609022
to 5751102
a 5100569
and 4892246
of 4867879
in 4431149
's 2202754
was 2086001
for 1995054
that 1944328
' 1880335
on 1858606
` 1821696
is 1797908
he 1678396
it 1603145
with 1497568
said 1348297
: 1344327
his 1302056
at 1260578
as 1230256
i 1089458
by 1064355
have 1016505
from 1015625
has 969042
her 935151
be 932950
'' 904149
`` 898933
but 884494
are 865728
she 850971
they 816011
an 766001
not 738121
had 725375
who 722127
this 721027
after 669231
were 655187
been 647432
their 645014
we 625684
will 577581
when 506811
-rrb- 501827
n't 499765
-lrb- 497508
one 490666
which 465040
you 461359
-- 460450
up 437177
more 433177
out 432343
about 428037
would 400420
- 399113
or 399001
there 389590
people 386121
new 380970
also 380041
all 350670
two 343787
can 341110
him 338345
do 330166
into 319067
last 315857
so 308507
than 306701
just 305759
time 302071
police 301341
could 298919
told 298384
over 297568
if 297292
what 293759
years 288999
first 283683
no 274488
my 273829
year 272392
them 270715
its 269566
now 262011
before 260991
mr 250970
other 247663
some 245191
being 243458
home 229570
like 229425
did 227833
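The vocab file above stores one token per line followed by its corpus count, sorted by frequency; the `vocab_path` argument of ExtCNNDMPipe points at a file in this format. As a minimal sketch (the helper below is hypothetical and not part of fastNLP), such a file can be read and truncated to the most frequent `vocab_size` tokens like this:

```python
# Hypothetical helper: read a cnndm.vocab-style file ("<token> <count>" per line,
# most frequent first) and keep only the top `vocab_size` tokens.
def read_vocab_file(path, vocab_size):
    words = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            words.append(parts[0])   # first field is the token, second its count
            if len(words) >= vocab_size:
                break
    return words

# e.g. words = read_vocab_file("test/data_for_tests/cnndm.vocab", 100000)
```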
59
test/io/pipe/test_extcnndm.py
Normal file
@ -0,0 +1,59 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

# __author__="Danqing Wang"

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest
import os
# import sys
#
# sys.path.append("../../../")

from fastNLP.io import DataBundle
from fastNLP.io.pipe.summarization import ExtCNNDMPipe


class TestRunExtCNNDMPipe(unittest.TestCase):

    def test_load(self):
        data_set_dict = {
            'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'},
        }
        vocab_size = 100000
        VOCAL_FILE = 'test/data_for_tests/cnndm.vocab'
        sent_max_len = 100
        doc_max_timesteps = 50
        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
                              vocab_path=VOCAL_FILE,
                              sent_max_len=sent_max_len,
                              doc_max_timesteps=doc_max_timesteps)
        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
                               vocab_path=VOCAL_FILE,
                               sent_max_len=sent_max_len,
                               doc_max_timesteps=doc_max_timesteps,
                               domain=True)
        for k, v in data_set_dict.items():
            db = dbPipe.process_from_file(v)
            db2 = dbPipe2.process_from_file(v)

            # print(db2.get_dataset("train"))

            self.assertTrue(isinstance(db, DataBundle))
            self.assertTrue(isinstance(db2, DataBundle))
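For reference, a short sketch of how the DataBundle produced by this pipe can be consumed downstream, following the accessors used in train.py above (`get_vocab("vocab")`, `get_dataset(...)`); the paths and sizes simply reuse the test fixtures and are not meant as recommended settings:

```python
from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# Build the pipe with the same parameters as the test above.
pipe = ExtCNNDMPipe(vocab_size=100000,
                    vocab_path='test/data_for_tests/cnndm.vocab',
                    sent_max_len=100,
                    doc_max_timesteps=50)
db = pipe.process_from_file({"train": 'test/data_for_tests/cnndm.jsonl'})

vocab = db.get_vocab("vocab")        # word vocabulary built during preprocessing
train_ds = db.get_dataset("train")   # padded/truncated extractive-summarization instances
print(len(vocab), len(train_ds))
```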