mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
add BertSum
This commit is contained in:
parent
ca94a00f41
commit
248eefe9eb
129
reproduction/Summmarization/BertSum/callback.py
Normal file
129
reproduction/Summmarization/BertSum/callback.py
Normal file
@ -0,0 +1,129 @@
|
||||
import os
|
||||
import torch
|
||||
import sys
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.core.callback import Callback
|
||||
from fastNLP.core.utils import _get_model_device
|
||||
|
||||
class MyCallback(Callback):
    """Callback applying a warm-up / inverse-square-root learning-rate schedule.

    After every optimizer update (every ``self.update_every`` training steps)
    the learning rate of all parameter groups is recomputed as::

        max_lr * sqrt(warmup_steps) * min(step ** -0.5, step * warmup_steps ** -1.5)

    so the rate grows linearly for ``warmup_steps`` updates, peaks at exactly
    ``max_lr`` and then decays with the inverse square root of the update count.

    :param args: namespace providing ``max_lr`` and ``warmup_steps``.
    """

    def __init__(self, args):
        super(MyCallback, self).__init__()
        self.args = args
        # Number of optimizer updates seen so far (not the number of mini-batches).
        self.real_step = 0

    def on_step_end(self):
        """Recompute and apply the learning rate once per optimizer update."""
        # Only act on steps that complete an accumulation cycle.
        if self.step <= 0 or self.step % self.update_every != 0:
            return

        self.real_step += 1
        warmup_steps = self.args.warmup_steps
        # The original code used a literal 100 here, which is sqrt(10000) and
        # therefore only peaked at ``max_lr`` for the default warm-up length.
        # Deriving the factor from ``warmup_steps`` keeps the default behaviour
        # and makes the peak equal ``max_lr`` for any warm-up length.
        peak_scale = self.args.max_lr * warmup_steps ** 0.5
        cur_lr = peak_scale * min(self.real_step ** (-0.5),
                                  self.real_step * warmup_steps ** (-1.5))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = cur_lr

        if self.real_step % 1000 == 0:
            self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step))

    def on_epoch_end(self):
        """Report the end of the current epoch on the progress bar."""
        self.pbar.write('Epoch {} is done !!!'.format(self.epoch))
|
||||
|
||||
def _save_model(model, model_name, save_dir, only_param=False):
    """Save a model, or only its ``state_dict``, without any GPU placement information.

    :param model: the module to save; an ``nn.DataParallel`` wrapper is
        unwrapped so that the underlying module is what gets stored.
    :param model_name: file name of the checkpoint inside ``save_dir``.
    :param save_dir: directory to save into; created if it does not exist.
    :param only_param: if True, store only the ``state_dict`` (tensors moved
        to CPU); otherwise store the whole model object after moving it to CPU.
    :return: None
    """
    model_path = os.path.join(save_dir, model_name)
    os.makedirs(save_dir, exist_ok=True)
    if isinstance(model, nn.DataParallel):
        model = model.module
    if only_param:
        state_dict = model.state_dict()
        # Re-assigning existing keys keeps the dict type and its metadata intact.
        for key in state_dict:
            state_dict[key] = state_dict[key].cpu()
        torch.save(state_dict, model_path)
    else:
        _model_device = _get_model_device(model)
        model.cpu()
        try:
            torch.save(model, model_path)
        finally:
            # Always move the model back, even when saving fails, so a failed
            # checkpoint does not silently leave training running on the CPU.
            model.to(_model_device)
|
||||
|
||||
class SaveModelCallback(Callback):
    """Keep the ``top`` best checkpoints seen during validation.

    The Trainer itself only keeps the single best model during training; this
    callback allows several models to be stored. When training begins, a
    sub-directory named after the trainer's start time is used below
    ``save_dir`` and the checkpoints are written inside it::

        -save_dir
            -2019-07-03-15-06-36
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt
                -epoch:1_step:40_{metric_key}:{evaluate_performance}.pt
            -2019-07-03-15-10-00
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt

    :param str save_dir: existing directory below which the timestamped
        directory and the models are stored.
    :param int top: how many of the best models (by dev performance) to keep;
        a negative value keeps all of them.
    :param bool only_param: whether to save only the model weights.
    :param save_on_exception: whether to save the current model when an
        exception is raised during training.
    """

    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
        super().__init__()

        if not os.path.isdir(save_dir):
            # NOTE(review): NotADirectoryError would describe this better; the
            # type is kept so that existing ``except`` clauses keep working.
            raise IsADirectoryError("{} is not a directory.".format(save_dir))
        self.save_dir = save_dir
        self.top = sys.maxsize if top < 0 else top
        # List of (metric_value, model_name) tuples ordered from worst to best,
        # so the entry to evict is always at the front.
        self._ordered_save_models = []

        self.only_param = only_param
        self.save_on_exception = save_on_exception

    def on_train_begin(self):
        """Redirect saving into a sub-directory named after the training start time."""
        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """Read ``metric_key`` from the first metric's result and save if it ranks high enough."""
        metric_value = list(eval_result.values())[0][metric_key]
        self._save_this_model(metric_value)

    def _insert_into_ordered_save_models(self, pair):
        """Insert ``pair`` into the ranking and report what to save and delete.

        :param pair: ``(metric_value, model_name)``.
        :return: ``(save_pair, delete_pair)``; ``save_pair`` is ``pair`` when it
            entered the ranking (else None), ``delete_pair`` is the evicted
            entry when the ranking grew beyond ``top`` (else None).
        """
        # ``index`` ends up as the position of the last stored entry that is
        # strictly worse than ``pair`` (-1 when there is none).
        index = -1
        for _pair in self._ordered_save_models:
            if _pair[0] >= pair[0] and self.trainer.increase_better:
                break
            if not self.trainer.increase_better and _pair[0] <= pair[0]:
                break
            index += 1

        save_pair = None
        # Accept while there is room, or when the new model beats at least one
        # stored model.
        if len(self._ordered_save_models) < self.top or index != -1:
            save_pair = pair
            self._ordered_save_models.insert(index + 1, pair)

        delete_pair = None
        if len(self._ordered_save_models) > self.top:
            delete_pair = self._ordered_save_models.pop(0)
        return save_pair, delete_pair

    def _save_this_model(self, metric_value):
        """Save the current model if it ranks in the top and remove the evicted one.

        Failures while saving or deleting are reported on stdout and do not
        interrupt training.
        """
        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
        if save_pair:
            try:
                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
            except Exception as e:
                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
        if delete_pair:
            try:
                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
                if os.path.exists(delete_model_path):
                    os.remove(delete_model_path)
            except Exception as e:
                # Report the model that failed to be deleted, not the one just saved.
                print(f"Fail to delete model {delete_pair[1]} at {self.save_dir} caused by exception:{e}.")

    def on_exception(self, exception):
        """Save a copy of the model when training fails, if enabled."""
        if self.save_on_exception:
            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
|
||||
|
||||
|
157
reproduction/Summmarization/BertSum/dataloader.py
Normal file
157
reproduction/Summmarization/BertSum/dataloader.py
Normal file
@ -0,0 +1,157 @@
|
||||
from time import time
|
||||
from datetime import timedelta
|
||||
|
||||
from fastNLP.io.dataset_loader import JsonLoader
|
||||
from fastNLP.modules.encoder._bert import BertTokenizer
|
||||
from fastNLP.io.base_loader import DataInfo
|
||||
from fastNLP.core.const import Const
|
||||
|
||||
class BertData(JsonLoader):
    """Loader converting ``article`` / ``label`` JSON records into BERT input ids.

    :param max_nsents: maximum number of sentences kept per article.
    :param max_ntokens: maximum number of whitespace tokens kept per sentence.
    :param max_len: length every token-id sequence is truncated / padded to.
    :param bert_path: directory of the pretrained BERT model whose vocabulary
        is used for tokenization.
    """

    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512,
                 bert_path='/path/to/uncased_L-12_H-768_A-12'):
        fields = {'article': 'article',
                  'label': 'label'}
        super(BertData, self).__init__(fields=fields)

        self.max_nsents = max_nsents
        self.max_ntokens = max_ntokens
        self.max_len = max_len

        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def _load(self, paths):
        dataset = super(BertData, self)._load(paths)
        return dataset

    def process(self, paths):
        """Load every dataset in ``paths`` and derive the model input fields.

        :param paths: mapping from dataset name to the path of its data file.
        :return: a ``DataInfo`` wrapping the processed datasets.
        """

        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
            # Lower-case, cap each sentence at max_ntokens tokens, then cap the
            # number of sentences.
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]

        def truncate_labels(instance):
            # Drop labels pointing at sentences removed by truncation.
            n_sents = len(instance['article'])
            return [x for x in instance['label'] if x < n_sents]

        def bert_tokenize(instance, tokenizer, max_len, pad_value):
            # Every sentence ends up wrapped as "[CLS] sentence [SEP]".
            article = ' [SEP] [CLS] '.join(instance['article'])
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
            token_ids += [pad_value] * (max_len - len(token_ids))
            assert len(token_ids) == max_len
            return token_ids

        def get_seg_id(instance, max_len, sep_id):
            # Segment ids alternate 0/1 between consecutive [SEP]-terminated spans.
            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
            segment_id = []
            for i, length in enumerate(segs):
                segment_id += length * [i % 2]
            segment_id += [0] * (max_len - len(segment_id))
            return segment_id

        def get_cls_id(instance, cls_id):
            # Positions of the [CLS] tokens.
            return [i for i, idx in enumerate(instance['article']) if idx == cls_id]

        def get_labels(instance):
            # One 0/1 entry per [CLS] position; 1 for positions listed in 'label'.
            n_cls = len(instance['cls_id'])
            labels = [0] * n_cls
            for idx in instance['label']:
                if idx < n_cls:
                    labels[idx] = 1
            return labels

        datasets = {}
        for name in paths:
            datasets[name] = self._load(paths[name])

            # remove empty samples
            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)

            # truncate articles
            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')

            # truncate labels
            datasets[name].apply(truncate_labels, new_field_name='label')

            # tokenize and convert tokens to id
            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')

            # get segment id
            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')

            # get classification id
            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')

            # get label
            datasets[name].apply(get_labels, new_field_name='label')

            # rename fields ('label' was misspelled as 'lbael' before)
            datasets[name].rename_field('article', Const.INPUTS(0))
            datasets[name].rename_field('segment_id', Const.INPUTS(1))
            datasets[name].rename_field('cls_id', Const.INPUTS(2))
            datasets[name].rename_field('label', Const.TARGET)

            # set input and target
            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
            datasets[name].set_target(Const.TARGET)

            # set padding value; the field no longer exists under the name
            # 'article' after the rename above, so use its new name
            datasets[name].set_pad_val(Const.INPUTS(0), 0)

        return DataInfo(datasets=datasets)
|
||||
|
||||
|
||||
class BertSumLoader(JsonLoader):
    """Loader for JSON records that already hold ``article``, ``segment_id``, ``cls_id`` and ``label``."""

    def __init__(self):
        field_map = {
            'article': 'article',
            'segment_id': 'segment_id',
            'cls_id': 'cls_id',
            'label': Const.TARGET,
        }
        super(BertSumLoader, self).__init__(fields=field_map)

    def _load(self, paths):
        return super(BertSumLoader, self)._load(paths)

    def process(self, paths):
        """Load each dataset named in ``paths`` and mark inputs, target and padding values.

        :param paths: mapping from dataset name to the path of its data file.
        :return: a ``DataInfo`` wrapping the loaded datasets.
        """
        print('Start loading datasets !!!')
        began_at = time()

        # (field name, padding value) pairs, applied in this order
        pad_values = (
            ('article', 0),
            ('segment_id', 0),
            ('cls_id', -1),
            (Const.TARGET, 0),
        )

        datasets = {}
        for name in paths:
            dataset = self._load(paths[name])
            datasets[name] = dataset

            # sequence length is the number of entries in 'article'
            dataset.apply(lambda instance: len(instance['article']), new_field_name='seq_len')

            # set input and target
            dataset.set_input('article', 'segment_id', 'cls_id')
            dataset.set_target(Const.TARGET)

            # set padding value
            for field_name, pad_val in pad_values:
                dataset.set_pad_val(field_name, pad_val)

        elapsed = timedelta(seconds=time() - began_at)
        print('Finished in {}'.format(elapsed))

        return DataInfo(datasets=datasets)
|
178
reproduction/Summmarization/BertSum/metrics.py
Normal file
178
reproduction/Summmarization/BertSum/metrics.py
Normal file
@ -0,0 +1,178 @@
|
||||
import numpy as np
|
||||
import json
|
||||
from os.path import join
|
||||
import torch
|
||||
import logging
|
||||
import tempfile
|
||||
import subprocess as sp
|
||||
from datetime import timedelta
|
||||
from time import time
|
||||
|
||||
from pyrouge import Rouge155
|
||||
from pyrouge.utils import log
|
||||
|
||||
from fastNLP.core.losses import LossBase
|
||||
from fastNLP.core.metrics import MetricBase
|
||||
|
||||
# Directory expected to contain the ``ROUGE-1.5.5.pl`` script and its ``data``
# folder. Presumably a placeholder to be replaced by a local path -- confirm.
_ROUGE_PATH = '/path/to/RELEASE-1.5.5'
|
||||
|
||||
class MyBCELoss(LossBase):
    """Binary cross-entropy summed over the positions selected by ``mask``."""

    def __init__(self, pred=None, target=None, mask=None):
        super(MyBCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        # Element-wise losses are kept so that the mask can be applied before summing.
        self.loss_func = torch.nn.BCELoss(reduction='none')

    def get_loss(self, pred, target, mask):
        """Return the sum of the element-wise BCE losses weighted by ``mask``."""
        elementwise = self.loss_func(pred, target.float())
        masked = elementwise * mask.float()
        return masked.sum()
|
||||
|
||||
class LossMetric(MetricBase):
    """Metric reporting the masked BCE loss averaged over the evaluated samples."""

    def __init__(self, pred=None, target=None, mask=None):
        super(LossMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.loss_func = torch.nn.BCELoss(reduction='none')
        # Running SUM of the batch losses (the name is kept for compatibility).
        self.avg_loss = 0.0
        # Number of samples accumulated so far.
        self.nsamples = 0

    def evaluate(self, pred, target, mask):
        """Accumulate the masked loss and the sample count of one batch."""
        batch_size = pred.size(0)
        loss = self.loss_func(pred, target.float())
        loss = (loss * mask.float()).sum()
        # Store a Python float so no tensor is kept alive between batches.
        self.avg_loss += loss.item()
        self.nsamples += batch_size

    def get_metric(self, reset=True):
        """Return ``{'loss': mean loss per sample}``.

        :param reset: when True, clear the accumulated state afterwards.
        :raises ZeroDivisionError: if no sample has been evaluated.
        """
        # Use a local variable: overwriting the running sum with the mean would
        # corrupt the state when ``reset`` is False.
        mean_loss = self.avg_loss / self.nsamples
        eval_result = {'loss': mean_loss}
        if reset:
            self.avg_loss = 0
            self.nsamples = 0
        return eval_result
|
||||
|
||||
class RougeMetric(MetricBase):
    """Metric that writes extracted summaries to disk and scores them with ROUGE-1.5.5.

    :param data_path: JSON-lines file with the original data; every line holds
        either ``text``/``summary`` or ``article``/``abstract``.
    :param dec_path: directory receiving the decoded summaries (``{i}.dec``).
    :param ref_path: directory receiving the reference summaries (``{i}.ref``).
    :param n_total: total number of samples, used only for progress output.
    :param n_ext: maximum number of sentences extracted per sample.
    :param ngram_block: n-gram size used for n-gram blocking.
    :param pred: name mapping passed to ``_init_param_map``.
    :param target: name mapping passed to ``_init_param_map``.
    :param mask: name mapping passed to ``_init_param_map``.
    """

    def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
        super(RougeMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.data_path = data_path
        self.dec_path = dec_path
        self.ref_path = ref_path
        self.n_total = n_total
        self.n_ext = n_ext
        self.ngram_block = ngram_block

        self.cur_idx = 0
        # One ranking of sentence indices (best first) per evaluated sample.
        self.ext = []
        self.start = time()

    @staticmethod
    def eval_rouge(dec_dir, ref_dir):
        """Run the ROUGE-1.5.5 perl script on two directories of summaries.

        :param dec_dir: directory with the system summaries (``<id>.dec``).
        :param ref_dir: directory with the reference summaries (``<id>.ref``).
        :return: tuple ``(R_1, R_2, R_L)`` parsed from fixed lines of the
            script output.
        """
        assert _ROUGE_PATH is not None
        log.get_global_console_logger().setLevel(logging.WARNING)
        # Raw string: '\d' is an invalid escape sequence in a normal literal.
        dec_pattern = r'(\d+).dec'
        ref_pattern = '#ID#.ref'
        with tempfile.TemporaryDirectory() as tmp_dir:
            Rouge155.convert_summaries_to_rouge_format(
                dec_dir, join(tmp_dir, 'dec'))
            Rouge155.convert_summaries_to_rouge_format(
                ref_dir, join(tmp_dir, 'ref'))
            settings_path = join(tmp_dir, 'settings.xml')
            Rouge155.write_config_static(
                join(tmp_dir, 'dec'), dec_pattern,
                join(tmp_dir, 'ref'), ref_pattern,
                settings_path, system_id=1
            )
            # Build the argument list directly instead of splitting a string on
            # spaces, so that paths containing spaces are passed intact.
            cmd = [join(_ROUGE_PATH, 'ROUGE-1.5.5.pl'),
                   '-e', join(_ROUGE_PATH, 'data'),
                   '-c', '95', '-r', '1000', '-n', '2', '-m',
                   '-a', settings_path]
            output = sp.check_output(cmd, universal_newlines=True)
        lines = output.split('\n')
        R_1 = float(lines[3].split(' ')[3])
        R_2 = float(lines[7].split(' ')[3])
        R_L = float(lines[11].split(' ')[3])
        print(output)
        return R_1, R_2, R_L

    def evaluate(self, pred, target, mask):
        """Store, for every sample of the batch, its sentence indices sorted by score."""
        # Adding the mask lifts unmasked positions by 1 so that they sort ahead
        # of masked ones (assuming the scores themselves lie in [0, 1]).
        pred = pred + mask.float()
        pred = pred.cpu().data.numpy()
        ext_ids = np.argsort(-pred, 1)
        for sent_id in ext_ids:
            self.ext.append(sent_id)
            self.cur_idx += 1
            print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
            ), end='')

    @staticmethod
    def _ngrams(sentence, n):
        """Return the space-joined ``n``-grams of ``sentence`` (split on single spaces)."""
        tokens = sentence.split(' ')
        return [' '.join(tokens[i: i + n]) for i in range(len(tokens) - n + 1)]

    def _load_original_data(self):
        """Read ``data_path`` and return a list of dicts with ``article`` and ``abstract``."""
        data = []
        with open(self.data_path) as f:
            for line in f:
                cur_data = json.loads(line)
                if 'text' in cur_data:
                    cur_data = {'article': cur_data['text'],
                                'abstract': cur_data['summary']}
                data.append(cur_data)
        return data

    def _select_sentences(self, article, ext_ids, use_ngram_block):
        """Pick up to ``n_ext`` sentences of ``article`` following the ranking ``ext_ids``."""
        dec = []
        seen_ngrams = set()
        for idx in ext_ids:
            if len(dec) >= self.n_ext:
                break
            if idx >= len(article):
                # Ranked position without a matching sentence (e.g. padding).
                continue
            sent = article[idx]
            if use_ngram_block:
                grams = self._ngrams(sent, self.ngram_block)
                # Skip sentences sharing an n-gram with an already chosen one.
                if not seen_ngrams.isdisjoint(grams):
                    continue
                seen_ngrams.update(grams)
            dec.append(sent)
        return dec

    def get_metric(self, use_ngram_block=True, reset=True):
        """Write decoded and reference summaries to disk and return their ROUGE scores.

        :param use_ngram_block: whether to skip sentences that share an n-gram
            with an already selected sentence.
        :param reset: when True, clear the accumulated rankings afterwards.
        :return: dict with the keys ``ROUGE-1``, ``ROUGE-2`` and ``ROUGE-L``.
        """
        # load original data
        data = self._load_original_data()

        # write decode sentences and references
        if use_ngram_block:
            print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
        for i, ext_ids in enumerate(self.ext):
            dec = self._select_sentences(data[i]['article'], ext_ids, use_ngram_block)
            ref = list(data[i]['abstract'])

            with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
                for sent in dec:
                    print(sent, file=f)
            with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
                for sent in ref:
                    print(sent, file=f)

        print('\nStart evaluating ROUGE score !!!')
        R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
        eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L': R_L}

        if reset:
            self.cur_idx = 0
            self.ext = []
            self.start = time()
        return eval_result
|
51
reproduction/Summmarization/BertSum/model.py
Normal file
51
reproduction/Summmarization/BertSum/model.py
Normal file
@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
|
||||
from fastNLP.modules.encoder._bert import BertModel
|
||||
|
||||
|
||||
class Classifier(nn.Module):
    """Linear layer followed by a sigmoid that scores every position, zeroing masked ones."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        """Return ``sigmoid(linear(inputs))`` with the last dimension squeezed, multiplied by ``mask_cls``."""
        logits = self.linear(inputs)
        # drop the trailing singleton dimension produced by the linear layer
        logits = logits.squeeze(-1)
        scores = self.sigmoid(logits)
        return scores * mask_cls.float()
|
||||
|
||||
|
||||
class BertSum(nn.Module):
    """BERT encoder with a per-sentence scoring head for extractive summarization.

    :param hidden_size: size of the encoder's hidden vectors.
    :param bert_path: directory of the pretrained BERT model to load.
    """

    def __init__(self, hidden_size=768, bert_path='/path/to/uncased_L-12_H-768_A-12'):
        super(BertSum, self).__init__()

        self.hidden_size = hidden_size

        self.encoder = BertModel.from_pretrained(bert_path)
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        """Score the positions listed in ``cls_id``.

        :param article: token ids, with 0 marking padding.
        :param segment_id: token type ids passed to the encoder.
        :param cls_id: positions to score for each sample, with -1 marking padding.
        :return: dict with ``pred`` (scores) and ``mask`` (1 for real entries
            of ``cls_id``, 0 for padding).
        """
        # ``1 - (x == 0)`` fails on recent PyTorch versions, where comparisons
        # return bool tensors that do not support subtraction. Comparing with
        # ``!=`` and casting works on old and new versions alike.
        input_mask = (article != 0).long()
        mask_cls = (cls_id != -1).long()
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()

        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # last layer

        # Gather the hidden vector at every listed position. Padding entries
        # (-1) pick the last token and are zeroed by the mask right after.
        batch_index = torch.arange(bert_out.size(0), device=cls_id.device).unsqueeze(1)
        sent_emb = bert_out[batch_index, cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, seq_len, hidden_size]

        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, seq_len]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))

        return {'pred': sent_scores, 'mask': mask_cls}
|
147
reproduction/Summmarization/BertSum/train_BertSum.py
Normal file
147
reproduction/Summmarization/BertSum/train_BertSum.py
Normal file
@ -0,0 +1,147 @@
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from time import time
|
||||
from datetime import timedelta
|
||||
from os.path import join, exists
|
||||
from torch.optim import Adam
|
||||
|
||||
from utils import get_data_path, get_rouge_path
|
||||
|
||||
from dataloader import BertSumLoader
|
||||
from model import BertSum
|
||||
from fastNLP.core.optimizer import AdamW
|
||||
from metrics import MyBCELoss, LossMetric, RougeMetric
|
||||
from fastNLP.core.sampler import BucketSampler
|
||||
from callback import MyCallback, SaveModelCallback
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.core.tester import Tester
|
||||
|
||||
|
||||
def configure_training(args):
    """Split the parsed arguments into a GPU id list and a dict of training hyper-parameters.

    :param args: namespace with ``gpus`` (comma separated ids) and the
        hyper-parameter attributes copied below.
    :return: ``(devices, params)``.
    """
    devices = list(map(int, args.gpus.split(',')))
    # Hyper-parameters copied verbatim from ``args``, in this order.
    param_names = (
        'label_type',
        'batch_size',
        'accum_count',
        'max_lr',
        'warmup_steps',
        'n_epochs',
        'valid_steps',
    )
    params = {key: getattr(args, key) for key in param_names}
    return devices, params
|
||||
|
||||
def train_model(args):
    """Train BertSum with the settings in ``args``.

    The hyper-parameters are also written to ``params.json`` inside
    ``args.save_path``.

    :param args: parsed command-line namespace.
    :raises FileNotFoundError: if one of the dataset files does not exist.
    """
    # check if the data_path and save_path exists
    data_paths = get_data_path(args.mode, args.label_type)
    for name, path in data_paths.items():
        # Explicit check instead of ``assert``, which is stripped under ``-O``
        # and gives no hint about the missing file.
        if not exists(path):
            raise FileNotFoundError("{} data file not found: {}".format(name, path))
    os.makedirs(args.save_path, exist_ok=True)

    # load summarization datasets
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    train_set = datasets.datasets['train']
    valid_set = datasets.datasets['val']

    # configure training
    devices, train_params = configure_training(args)
    with open(join(args.save_path, 'params.json'), 'w') as f:
        json.dump(train_params, f, indent=4)
    print('Devices is:')
    print(devices)

    # configure model
    model = BertSum()
    # The learning rate starts at 0 and is set by MyCallback after each update.
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
    criterion = MyBCELoss()
    val_metric = [LossMetric()]
    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
                      loss=criterion, batch_size=args.batch_size,
                      update_every=args.accum_count, n_epochs=args.n_epochs,
                      print_every=100, dev_data=valid_set, metrics=val_metric,
                      metric_key='-loss', validate_every=args.valid_steps,
                      save_path=args.save_path, device=devices, callbacks=callbacks)

    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()
|
||||
|
||||
def test_model(args):
    """Evaluate every saved model found directly in ``args.save_path`` with ROUGE.

    :param args: parsed command-line namespace; only the first id in
        ``args.gpus`` is used and ``args.batch_size`` is forced to 1.
    """
    # Only regular files can be checkpoints. Sub-directories and the
    # ``params.json`` written by ``train_model`` are skipped, and the names are
    # sorted so that the evaluation order is reproducible.
    models = sorted(
        entry for entry in os.listdir(args.save_path)
        if os.path.isfile(join(args.save_path, entry)) and not entry.endswith('.json')
    )

    # load dataset
    data_paths = get_data_path(args.mode, args.label_type)
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    test_set = datasets.datasets['test']

    # only need 1 gpu for testing; accept a comma separated list and use the first id
    device = int(args.gpus.split(',')[0])

    args.batch_size = 1

    # The ROUGE paths do not depend on the model, so resolve them once.
    original_path, dec_path, ref_path = get_rouge_path(args.label_type)

    for cur_model in models:

        print('Current model is {}'.format(cur_model))

        # load model
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # save_path at checkpoints from a trusted source.
        model = torch.load(join(args.save_path, cur_model))

        # configure testing
        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
                                  ref_path=ref_path, n_total=len(test_set))
        tester = Tester(data=test_set, model=model, metrics=[test_metric],
                        batch_size=args.batch_size, device=device)
        tester.test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='training/testing of BertSum(liu et al. 2019)'
    )

    # (flag, keyword arguments) pairs, registered in this order.
    # example for gpus input: '0,1,2,3'
    option_specs = [
        ('--mode', dict(required=True, type=str,
                        help='training or testing of BertSum')),
        ('--label_type', dict(default='greedy', type=str,
                              help='greedy/limit')),
        ('--save_path', dict(required=True, type=str,
                             help='root of the model')),
        ('--gpus', dict(required=True, type=str,
                        help='available gpus for training(separated by commas)')),
        ('--batch_size', dict(default=18, type=int,
                              help='the training batch size')),
        ('--accum_count', dict(default=2, type=int,
                               help='number of updates steps to accumulate before performing a backward/update pass.')),
        ('--max_lr', dict(default=2e-5, type=float,
                          help='max learning rate for warm up')),
        ('--warmup_steps', dict(default=10000, type=int,
                                help='warm up steps for training')),
        ('--n_epochs', dict(default=10, type=int,
                            help='total number of training epochs')),
        ('--valid_steps', dict(default=1000, type=int,
                               help='number of update steps for checkpoint and validation')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Any mode other than 'train' runs the testing path.
    if args.mode != 'train':
        print('Testing process of BertSum !!!')
        test_model(args)
    else:
        print('Training process of BertSum !!!')
        train_model(args)
|
||||
|
||||
|
||||
|
||||
|
24
reproduction/Summmarization/BertSum/utils.py
Normal file
24
reproduction/Summmarization/BertSum/utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
import os
|
||||
from os.path import exists
|
||||
|
||||
def get_data_path(mode, label_type):
    """Return the dataset files for ``mode`` as a ``{split name: path}`` dict.

    ``'train'`` yields the ``train`` and ``val`` splits; any other mode yields
    only the ``test`` split. All paths live under ``data/<label_type>/``.
    """
    splits = ('train', 'val') if mode == 'train' else ('test',)
    return {
        split: '/'.join(('data', label_type, 'bert.' + split + '.jsonl'))
        for split in splits
    }
|
||||
|
||||
def get_rouge_path(label_type):
    """Return ``(data_path, dec_path, ref_path)`` used for ROUGE evaluation.

    The ``ref`` and ``dec`` directories are created relative to the current
    working directory when they do not exist yet.
    """
    file_name = 'bert.test.jsonl' if label_type == 'others' else 'test.jsonl'
    data_path = 'data/' + label_type + '/' + file_name

    dec_path, ref_path = 'dec', 'ref'
    # create the reference directory first, then the decode directory
    for directory in (ref_path, dec_path):
        if not exists(directory):
            os.makedirs(directory)
    return data_path, dec_path, ref_path
|
Loading…
Reference in New Issue
Block a user