add BertSum

This commit is contained in:
maszhongming 2019-07-08 16:10:01 +09:00
parent ca94a00f41
commit 248eefe9eb
6 changed files with 686 additions and 0 deletions


@@ -0,0 +1,129 @@
import os
import torch
import sys
from torch import nn
from fastNLP.core.callback import Callback
from fastNLP.core.utils import _get_model_device
class MyCallback(Callback):
def __init__(self, args):
super(MyCallback, self).__init__()
self.args = args
self.real_step = 0
def on_step_end(self):
if self.step % self.update_every == 0 and self.step > 0:
self.real_step += 1
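# Noam-style schedule: lr grows linearly for the first `warmup_steps` real
# (post-accumulation) steps, then decays as real_step ** -0.5; with the
# defaults max_lr=2e-5 and warmup_steps=10000 the peak value is max_lr itself.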
cur_lr = self.args.max_lr * 100 * min(self.real_step ** (-0.5), self.real_step * self.args.warmup_steps**(-1.5))
for param_group in self.optimizer.param_groups:
param_group['lr'] = cur_lr
if self.real_step % 1000 == 0:
self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step))
def on_epoch_end(self):
self.pbar.write('Epoch {} is done !!!'.format(self.epoch))
def _save_model(model, model_name, save_dir, only_param=False):
""" 存储不含有显卡信息的 state_dict 或 model
:param model:
:param model_name:
:param save_dir: 保存的 directory
:param only_param:
:return:
"""
model_path = os.path.join(save_dir, model_name)
if not os.path.isdir(save_dir):
os.makedirs(save_dir, exist_ok=True)
if isinstance(model, nn.DataParallel):
model = model.module
if only_param:
state_dict = model.state_dict()
for key in state_dict:
state_dict[key] = state_dict[key].cpu()
torch.save(state_dict, model_path)
else:
_model_device = _get_model_device(model)
model.cpu()
torch.save(model, model_path)
model.to(_model_device)
class SaveModelCallback(Callback):
"""
由于Trainer在训练过程中只会保存最佳的模型 callback 可实现多种方式的结果存储
会根据训练开始的时间戳在 save_dir 下建立文件夹在再文件夹下存放多个模型
-save_dir
-2019-07-03-15-06-36
-epoch0step20{metric_key}{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
-epoch1step40
-2019-07-03-15-10-00
-epoch:0step:20{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
:param str save_dir: 将模型存放在哪个目录下会在该目录下创建以时间戳命名的目录并存放模型
:param int top: 保存dev表现top多少模型-1为保存所有模型
:param bool only_param: 是否只保存模型权重
:param save_on_exception: 发生exception时是否保存一份当时的模型
"""
def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
super().__init__()
if not os.path.isdir(save_dir):
raise NotADirectoryError("{} is not a directory.".format(save_dir))
self.save_dir = save_dir
if top < 0:
self.top = sys.maxsize
else:
self.top = top
self._ordered_save_models = [] # List[Tuple]: Tuple[0] is the metric value, Tuple[1] is the model file name. The list is kept in improving order, so the worst model sits at the head and is deleted first.
self.only_param = only_param
self.save_on_exception = save_on_exception
def on_train_begin(self):
self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
metric_value = list(eval_result.values())[0][metric_key]
self._save_this_model(metric_value)
def _insert_into_ordered_save_models(self, pair):
# pair: (metric_value, model_name)
# Returns the pair that was saved and the pair that was deleted; in each pair the first element is the metric value, the second is the model file name.
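# Worked example (increase_better=True, top=2): with self._ordered_save_models == [(0.80, 'm1'), (0.85, 'm2')],
# inserting pair=(0.83, 'm3') returns save_pair=(0.83, 'm3') and delete_pair=(0.80, 'm1'),
# leaving [(0.83, 'm3'), (0.85, 'm2')].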
index = -1
for _pair in self._ordered_save_models:
if _pair[0]>=pair[0] and self.trainer.increase_better:
break
if not self.trainer.increase_better and _pair[0]<=pair[0]:
break
index += 1
save_pair = None
if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1):
save_pair = pair
self._ordered_save_models.insert(index+1, pair)
delete_pair = None
if len(self._ordered_save_models)>self.top:
delete_pair = self._ordered_save_models.pop(0)
return save_pair, delete_pair
def _save_this_model(self, metric_value):
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
if save_pair:
try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e:
print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
if delete_pair:
try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path):
os.remove(delete_model_path)
except Exception as e:
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
def on_exception(self, exception):
if self.save_on_exception:
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)


@@ -0,0 +1,157 @@
from time import time
from datetime import timedelta
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.modules.encoder._bert import BertTokenizer
from fastNLP.io.base_loader import DataInfo
from fastNLP.core.const import Const
class BertData(JsonLoader):
def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
fields = {'article': 'article',
'label': 'label'}
super(BertData, self).__init__(fields=fields)
self.max_nsents = max_nsents
self.max_ntokens = max_ntokens
self.max_len = max_len
self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
def _load(self, paths):
dataset = super(BertData, self)._load(paths)
return dataset
def process(self, paths):
def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
return article[:max_nsents]
def truncate_labels(instance):
label = list(filter(lambda x: x < len(instance['article']), instance['label']))
return label
def bert_tokenize(instance, tokenizer, max_len, pad_value):
article = instance['article']
article = ' [SEP] [CLS] '.join(article)
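# inserts '[SEP] [CLS]' between sentences; the leading [CLS] and trailing [SEP]
# are added after word-piece tokenization below, so every sentence ends up wrapped as [CLS] ... [SEP]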
word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
while len(token_ids) < max_len:
token_ids.append(pad_value)
assert len(token_ids) == max_len
return token_ids
def get_seg_id(instance, max_len, sep_id):
_segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
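# segs holds the token length of each [CLS] ... [SEP] sentence span; consecutive spans
# get alternating segment ids 0/1 (the interval segment embeddings of BertSum)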
segment_id = []
for i, length in enumerate(segs):
if i % 2 == 0:
segment_id += length * [0]
else:
segment_id += length * [1]
while len(segment_id) < max_len:
segment_id.append(0)
return segment_id
def get_cls_id(instance, cls_id):
classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
return classification_id
def get_labels(instance):
labels = [0] * len(instance['cls_id'])
label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
for idx in label_idx:
labels[idx] = 1
return labels
datasets = {}
for name in paths:
datasets[name] = self._load(paths[name])
# remove empty samples
datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)
# truncate articles
datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')
# truncate labels
datasets[name].apply(truncate_labels, new_field_name='label')
# tokenize and convert tokens to id
datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')
# get segment id
datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')
# get classification id
datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')
# get label
datasets[name].apply(get_labels, new_field_name='label')
# rename field
datasets[name].rename_field('article', Const.INPUTS(0))
datasets[name].rename_field('segment_id', Const.INPUTS(1))
datasets[name].rename_field('cls_id', Const.INPUTS(2))
datasets[name].rename_field('label', Const.TARGET)
# set input and target
datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
datasets[name].set_target(Const.TARGET)
# set padding value ('article' was renamed to Const.INPUTS(0) above)
datasets[name].set_pad_val(Const.INPUTS(0), 0)
return DataInfo(datasets=datasets)
class BertSumLoader(JsonLoader):
def __init__(self):
fields = {'article': 'article',
'segment_id': 'segment_id',
'cls_id': 'cls_id',
'label': Const.TARGET
}
super(BertSumLoader, self).__init__(fields=fields)
def _load(self, paths):
dataset = super(BertSumLoader, self)._load(paths)
return dataset
def process(self, paths):
def get_seq_len(instance):
return len(instance['article'])
print('Start loading datasets !!!')
start = time()
# load datasets
datasets = {}
for name in paths:
datasets[name] = self._load(paths[name])
datasets[name].apply(get_seq_len, new_field_name='seq_len')
# set input and target
datasets[name].set_input('article', 'segment_id', 'cls_id')
datasets[name].set_target(Const.TARGET)
# set padding value
datasets[name].set_pad_val('article', 0)
datasets[name].set_pad_val('segment_id', 0)
datasets[name].set_pad_val('cls_id', -1)
datasets[name].set_pad_val(Const.TARGET, 0)
print('Finished in {}'.format(timedelta(seconds=time()-start)))
return DataInfo(datasets=datasets)


@@ -0,0 +1,178 @@
import numpy as np
import json
from os.path import join
import torch
import logging
import tempfile
import subprocess as sp
from datetime import timedelta
from time import time
from pyrouge import Rouge155
from pyrouge.utils import log
from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
_ROUGE_PATH = '/path/to/RELEASE-1.5.5'
class MyBCELoss(LossBase):
def __init__(self, pred=None, target=None, mask=None):
super(MyBCELoss, self).__init__()
self._init_param_map(pred=pred, target=target, mask=mask)
self.loss_func = torch.nn.BCELoss(reduction='none')
def get_loss(self, pred, target, mask):
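# element-wise BCE over every sentence position, kept only where mask == 1
# (real sentences) and summed over the batch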
loss = self.loss_func(pred, target.float())
loss = (loss * mask.float()).sum()
return loss
class LossMetric(MetricBase):
def __init__(self, pred=None, target=None, mask=None):
super(LossMetric, self).__init__()
self._init_param_map(pred=pred, target=target, mask=mask)
self.loss_func = torch.nn.BCELoss(reduction='none')
self.avg_loss = 0.0
self.nsamples = 0
def evaluate(self, pred, target, mask):
batch_size = pred.size(0)
loss = self.loss_func(pred, target.float())
loss = (loss * mask.float()).sum()
self.avg_loss += loss
self.nsamples += batch_size
def get_metric(self, reset=True):
self.avg_loss = self.avg_loss / self.nsamples
eval_result = {'loss': self.avg_loss}
if reset:
self.avg_loss = 0
self.nsamples = 0
return eval_result
class RougeMetric(MetricBase):
def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
super(RougeMetric, self).__init__()
self._init_param_map(pred=pred, target=target, mask=mask)
self.data_path = data_path
self.dec_path = dec_path
self.ref_path = ref_path
self.n_total = n_total
self.n_ext = n_ext
self.ngram_block = ngram_block
self.cur_idx = 0
self.ext = []
self.start = time()
@staticmethod
def eval_rouge(dec_dir, ref_dir):
assert _ROUGE_PATH is not None
log.get_global_console_logger().setLevel(logging.WARNING)
dec_pattern = r'(\d+).dec'
ref_pattern = '#ID#.ref'
cmd = '-c 95 -r 1000 -n 2 -m'
with tempfile.TemporaryDirectory() as tmp_dir:
Rouge155.convert_summaries_to_rouge_format(
dec_dir, join(tmp_dir, 'dec'))
Rouge155.convert_summaries_to_rouge_format(
ref_dir, join(tmp_dir, 'ref'))
Rouge155.write_config_static(
join(tmp_dir, 'dec'), dec_pattern,
join(tmp_dir, 'ref'), ref_pattern,
join(tmp_dir, 'settings.xml'), system_id=1
)
cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl')
+ ' -e {} '.format(join(_ROUGE_PATH, 'data'))
+ cmd
+ ' -a {}'.format(join(tmp_dir, 'settings.xml')))
output = sp.check_output(cmd.split(' '), universal_newlines=True)
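# with this report layout, lines 3 / 7 / 11 of the ROUGE-1.5.5 output hold the
# Average_F scores of ROUGE-1 / ROUGE-2 / ROUGE-L respectively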
R_1 = float(output.split('\n')[3].split(' ')[3])
R_2 = float(output.split('\n')[7].split(' ')[3])
R_L = float(output.split('\n')[11].split(' ')[3])
print(output)
return R_1, R_2, R_L
def evaluate(self, pred, target, mask):
pred = pred + mask.float()
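# the Classifier already zeroes padded positions, so adding the 0/1 mask lifts every
# real sentence's score above any padded one and the argsort below never picks padding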
pred = pred.cpu().data.numpy()
ext_ids = np.argsort(-pred, 1)
for sent_id in ext_ids:
self.ext.append(sent_id)
self.cur_idx += 1
print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
), end='')
def get_metric(self, use_ngram_block=True, reset=True):
def check_n_gram(sentence, n, dic):
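# returns True when `sentence` shares no n-gram with those already stored in `dic`
# (the n-gram blocking trick used by BertSum to reduce redundancy)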
tokens = sentence.split(' ')
s_len = len(tokens)
for i in range(s_len):
if i + n > s_len:
break
if ' '.join(tokens[i: i + n]) in dic:
return False
return True # no n_gram overlap
# load original data
data = []
with open(self.data_path) as f:
for line in f:
cur_data = json.loads(line)
if 'text' in cur_data:
new_data = {}
new_data['article'] = cur_data['text']
new_data['abstract'] = cur_data['summary']
data.append(new_data)
else:
data.append(cur_data)
# write decode sentences and references
if use_ngram_block:
print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
for i, ext_ids in enumerate(self.ext):
dec, ref = [], []
if not use_ngram_block:
n_sent = min(len(data[i]['article']), self.n_ext)
for j in range(n_sent):
idx = ext_ids[j]
dec.append(data[i]['article'][idx])
else:
n_sent = len(ext_ids)
dic = {}
for j in range(n_sent):
sent = data[i]['article'][ext_ids[j]]
if check_n_gram(sent, self.ngram_block, dic):
dec.append(sent)
# update dic
tokens = sent.split(' ')
s_len = len(tokens)
for k in range(s_len):
if k + self.ngram_block > s_len:
break
dic[' '.join(tokens[k: k + self.ngram_block])] = 1
if len(dec) >= self.n_ext:
break
for sent in data[i]['abstract']:
ref.append(sent)
with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
for sent in dec:
print(sent, file=f)
with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
for sent in ref:
print(sent, file=f)
print('\nStart evaluating ROUGE score !!!')
R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L}
if reset:
self.cur_idx = 0
self.ext = []
self.start = time()
return eval_result


@@ -0,0 +1,51 @@
import torch
from torch import nn
from torch.nn import init
from fastNLP.modules.encoder._bert import BertModel
class Classifier(nn.Module):
def __init__(self, hidden_size):
super(Classifier, self).__init__()
self.linear = nn.Linear(hidden_size, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs, mask_cls):
h = self.linear(inputs).squeeze(-1) # [batch_size, seq_len]
sent_scores = self.sigmoid(h) * mask_cls.float()
return sent_scores
class BertSum(nn.Module):
def __init__(self, hidden_size=768):
super(BertSum, self).__init__()
self.hidden_size = hidden_size
self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
self.decoder = Classifier(self.hidden_size)
def forward(self, article, segment_id, cls_id):
# print(article.device)
# print(segment_id.device)
# print(cls_id.device)
input_mask = (article != 0).long() # 1 for real tokens, 0 for padding
mask_cls = (cls_id != -1).long() # 1 for real sentences, 0 for padded cls positions
assert input_mask.size() == article.size()
assert mask_cls.size() == cls_id.size()
bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
bert_out = bert_out[0][-1] # last layer
sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
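# advanced indexing: row i gathers bert_out[i, cls_id[i, j], :] for every j,
# i.e. the encoder output at each [CLS] position, one vector per sentence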
sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size) # [batch_size, seq_len, hidden_size]
sent_scores = self.decoder(sent_emb, mask_cls) # [batch_size, seq_len]
assert sent_scores.size() == (article.size(0), cls_id.size(1))
return {'pred': sent_scores, 'mask': mask_cls}


@@ -0,0 +1,147 @@
import sys
import argparse
import os
import json
import torch
from time import time
from datetime import timedelta
from os.path import join, exists
from torch.optim import Adam
from utils import get_data_path, get_rouge_path
from dataloader import BertSumLoader
from model import BertSum
from fastNLP.core.optimizer import AdamW
from metrics import MyBCELoss, LossMetric, RougeMetric
from fastNLP.core.sampler import BucketSampler
from callback import MyCallback, SaveModelCallback
from fastNLP.core.trainer import Trainer
from fastNLP.core.tester import Tester
def configure_training(args):
devices = [int(gpu) for gpu in args.gpus.split(',')]
params = {}
params['label_type'] = args.label_type
params['batch_size'] = args.batch_size
params['accum_count'] = args.accum_count
params['max_lr'] = args.max_lr
params['warmup_steps'] = args.warmup_steps
params['n_epochs'] = args.n_epochs
params['valid_steps'] = args.valid_steps
return devices, params
def train_model(args):
# check if the data_path and save_path exists
data_paths = get_data_path(args.mode, args.label_type)
for name in data_paths:
assert exists(data_paths[name])
if not exists(args.save_path):
os.makedirs(args.save_path)
# load summarization datasets
datasets = BertSumLoader().process(data_paths)
print('Information of dataset is:')
print(datasets)
train_set = datasets.datasets['train']
valid_set = datasets.datasets['val']
# configure training
devices, train_params = configure_training(args)
with open(join(args.save_path, 'params.json'), 'w') as f:
json.dump(train_params, f, indent=4)
print('Devices is:')
print(devices)
# configure model
model = BertSum()
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
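# lr starts at 0 here; MyCallback overwrites param_group['lr'] on every update step
# according to the warmup schedule defined in callback.py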
callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
criterion = MyBCELoss()
val_metric = [LossMetric()]
# sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
loss=criterion, batch_size=args.batch_size, # sampler=sampler,
update_every=args.accum_count, n_epochs=args.n_epochs,
print_every=100, dev_data=valid_set, metrics=val_metric,
metric_key='-loss', validate_every=args.valid_steps,
save_path=args.save_path, device=devices, callbacks=callbacks)
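# metric_key='-loss': the leading '-' tells fastNLP's Trainer that a smaller value is better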
print('Start training with the following hyper-parameters:')
print(train_params)
trainer.train()
def test_model(args):
models = os.listdir(args.save_path)
# load dataset
data_paths = get_data_path(args.mode, args.label_type)
datasets = BertSumLoader().process(data_paths)
print('Information of dataset is:')
print(datasets)
test_set = datasets.datasets['test']
# only need 1 gpu for testing
device = int(args.gpus)
args.batch_size = 1
for cur_model in models:
print('Current model is {}'.format(cur_model))
# load model
model = torch.load(join(args.save_path, cur_model))
# configure testing
original_path, dec_path, ref_path = get_rouge_path(args.label_type)
test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
ref_path=ref_path, n_total=len(test_set))
tester = Tester(data=test_set, model=model, metrics=[test_metric],
batch_size=args.batch_size, device=device)
tester.test()
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='training/testing of BertSum (Liu et al. 2019)'
)
parser.add_argument('--mode', required=True,
help='training or testing of BertSum', type=str)
parser.add_argument('--label_type', default='greedy',
help='greedy/limit', type=str)
parser.add_argument('--save_path', required=True,
help='root of the model', type=str)
# example for gpus input: '0,1,2,3'
parser.add_argument('--gpus', required=True,
help='available gpus for training (separated by commas)', type=str)
parser.add_argument('--batch_size', default=18,
help='the training batch size', type=int)
parser.add_argument('--accum_count', default=2,
help='number of update steps to accumulate before performing a backward/update pass.', type=int)
parser.add_argument('--max_lr', default=2e-5,
help='max learning rate for warm up', type=float)
parser.add_argument('--warmup_steps', default=10000,
help='warm up steps for training', type=int)
parser.add_argument('--n_epochs', default=10,
help='total number of training epochs', type=int)
parser.add_argument('--valid_steps', default=1000,
help='number of update steps for checkpoint and validation', type=int)
args = parser.parse_args()
if args.mode == 'train':
print('Training process of BertSum !!!')
train_model(args)
else:
print('Testing process of BertSum !!!')
test_model(args)


@@ -0,0 +1,24 @@
import os
from os.path import exists
def get_data_path(mode, label_type):
paths = {}
if mode == 'train':
paths['train'] = 'data/' + label_type + '/bert.train.jsonl'
paths['val'] = 'data/' + label_type + '/bert.val.jsonl'
else:
paths['test'] = 'data/' + label_type + '/bert.test.jsonl'
return paths
def get_rouge_path(label_type):
if label_type == 'others':
data_path = 'data/' + label_type + '/bert.test.jsonl'
else:
data_path = 'data/' + label_type + '/test.jsonl'
dec_path = 'dec'
ref_path = 'ref'
if not exists(ref_path):
os.makedirs(ref_path)
if not exists(dec_path):
os.makedirs(dec_path)
return data_path, dec_path, ref_path