Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-05 05:38:31 +08:00)
Commit eb01a5e833
@@ -1,6 +1,6 @@
 import os
 
-from typing import Union, Dict
+from typing import Union, Dict, List
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
@@ -33,7 +33,8 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
+                extra_split: List[str]=['-'], ) -> DataInfo:
         """
         :param paths: str or Dict[str, str]. If a str, it is either the folder containing the data files or a full file path: if it is a folder,
             the corresponding dataset names and file names are looked up in self.paths. If it is a Dict, the keys are dataset names (e.g. train, dev, test) and
@@ -56,6 +57,7 @@ class MatchingLoader(DataSetLoader):
         :param concat: whether to concatenate the two sentences. If False, they are not concatenated. If True, a <sep> token is inserted between the two sentences.
             If a list of length 4 is passed, the four entries are the markers inserted before the first sentence, after the first sentence, before the second sentence, and after the second sentence, respectively. If
             the string ``bert`` is passed, BERT-style concatenation is used, equivalent to ['[CLS]', '[SEP]', '', '[SEP]'].
+        :param extra_split: extra delimiters, i.e. characters other than whitespace that are also used to split tokens.
         :return:
         """
         if isinstance(set_input, str):
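
For context, the new extra_split argument is exercised by the matching_mwan.py script added later in this commit; a call would look roughly like the following sketch (the data path is a placeholder, copied from that script):

    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
        get_index=True, concat=False, extra_split=['/', '%', '-'],
    )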
@@ -89,6 +91,24 @@ class MatchingLoader(DataSetLoader):
             if Const.TARGET in data_set.get_field_names():
                 data_set.set_target(Const.TARGET)
 
+        if extra_split:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+
+                for s in extra_split:
+                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(0))
+                    data_set.apply(lambda x: x[Const.INPUTS(1)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(1))
+
+                _filt = lambda x: x
+                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))),
+                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
+                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))),
+                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)
+                _filt = None
+
         if to_lower:
             for data_name, data_set in data_info.datasets.items():
                 data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
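
A standalone sketch (not part of the patch) of what the extra_split block above does to one input field: the tokens are joined back into a string, every extra delimiter is padded with spaces, and the string is re-split with empty pieces filtered out:

    tokens = ['New', 'York-based', 'startup']
    extra_split = ['-']

    text = ' '.join(tokens)
    for s in extra_split:
        text = text.replace(s, ' ' + s + ' ')
    tokens = list(filter(lambda w: w, text.split(' ')))
    print(tokens)  # ['New', 'York', '-', 'based', 'startup']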

reproduction/matching/matching_mwan.py  (new file, 145 lines)
@@ -0,0 +1,145 @@
import sys

import os
import random

import numpy as np
import torch
from torch.optim import Adadelta, SGD
from torch.optim.lr_scheduler import StepLR

from tqdm import tqdm

from fastNLP import CrossEntropyLoss
from fastNLP import cache_results
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.core.predictor import Predictor
from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallback
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
from model.mwan import MwanModel

import fitlog
fitlog.debug()

import argparse


argument = argparse.ArgumentParser()
argument.add_argument('--task', choices = ['snli', 'rte', 'qnli', 'mnli'], default = 'snli')
argument.add_argument('--batch-size', type = int, default = 128)
argument.add_argument('--n-epochs', type = int, default = 50)
argument.add_argument('--lr', type = float, default = 1)
argument.add_argument('--testset-name', type = str, default = 'test')
argument.add_argument('--devset-name', type = str, default = 'dev')
argument.add_argument('--seed', type = int, default = 42)
argument.add_argument('--hidden-size', type = int, default = 150)
argument.add_argument('--dropout', type = float, default = 0.3)
arg = argument.parse_args()

random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
print(n_gpu)

for k in arg.__dict__:
    print(k, arg.__dict__[k], type(arg.__dict__[k]))

# load data set
if arg.task == 'snli':
    @cache_results(f'snli_mwan.pkl')
    def read_snli():
        data_info = SNLILoader().process(
            paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_snli()
elif arg.task == 'rte':
    @cache_results(f'rte_mwan.pkl')
    def read_rte():
        data_info = RTELoader().process(
            paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_rte()
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
        get_index=True, concat=False, cut_text=512, extra_split=['/', '%', '-'],
    )
elif arg.task == 'mnli':
    @cache_results(f'mnli_v0.9_mwan.pkl')
    def read_mnli():
        data_info = MNLILoader().process(
            paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_mnli()
else:
    raise RuntimeError(f'NOT support {arg.task} task yet!')

print(data_info)
print(len(data_info.vocabs['words']))


model = MwanModel(
    num_class = len(data_info.vocabs[Const.TARGET]),
    EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
    ElmoLayer = None,
    args_of_imm = {
        "input_size" : 300,
        "hidden_size" : arg.hidden_size,
        "dropout" : arg.dropout,
        "use_allennlp" : False,
    },
)


optimizer = Adadelta(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

callbacks = [
    LRScheduler(scheduler),
]

if arg.task in ['snli']:
    callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
elif arg.task == 'mnli':
    callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
                                     'dev_mismatched': data_info.datasets['dev_mismatched']},
                                    verbose=1))

trainer = Trainer(
    train_data = data_info.datasets['train'],
    model = model,
    optimizer = optimizer,
    num_workers = 0,
    batch_size = arg.batch_size,
    n_epochs = arg.n_epochs,
    print_every = -1,
    dev_data = data_info.datasets[arg.devset_name],
    metrics = AccuracyMetric(pred = "pred", target = "target"),
    metric_key = 'acc',
    device = [i for i in range(torch.cuda.device_count())],
    check_code_level = -1,
    callbacks = callbacks,
    loss = CrossEntropyLoss(pred = "pred", target = "target")
)
trainer.train(load_best_model=True)

tester = Tester(
    data=data_info.datasets[arg.testset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=arg.batch_size,
    device=[i for i in range(torch.cuda.device_count())],
)
tester.test()
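
A note on the schedule set up above: with the defaults (--lr 1, Adadelta wrapped in StepLR(step_size=10, gamma=0.5)), the learning rate is simply halved every 10 epochs. A minimal standalone sketch of that behaviour (the dummy parameter below stands in for model.parameters() and is not part of the script):

    import torch
    from torch.optim import Adadelta
    from torch.optim.lr_scheduler import StepLR

    params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model.parameters()
    optimizer = Adadelta(params, lr=1.0)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in range(50):
        # ... one training epoch would run here ...
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            print(epoch + 1, optimizer.param_groups[0]['lr'])
    # 10 -> 0.5, 20 -> 0.25, 30 -> 0.125, 40 -> 0.0625, 50 -> 0.03125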

reproduction/matching/model/mwan.py  (new file, 455 lines)
@@ -0,0 +1,455 @@
import torch as tc
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import math
from fastNLP.core.const import Const


class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidrect, dropout):
        super(RNNModel, self).__init__()

        if num_layers <= 1:
            dropout = 0.0

        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=dropout, bidirectional=bidrect)

        self.number = (2 if bidrect else 1) * num_layers

    def forward(self, x, mask):
        '''
        mask: (batch_size, seq_len)
        x: (batch_size, seq_len, input_size)
        '''
        lens = (mask).long().sum(dim=1)
        lens, idx_sort = tc.sort(lens, descending=True)
        _, idx_unsort = tc.sort(idx_sort)

        x = x[idx_sort]

        x = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)
        self.rnn.flatten_parameters()
        y, h = self.rnn(x)
        y, lens = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)

        h = h.transpose(0, 1).contiguous()  # make batch size first

        y = y[idx_unsort]  # (batch_size, seq_len, bid * hid_size)
        h = h[idx_unsort]  # (batch_size, number, hid_size)

        return y, h


class Contexualizer(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.3):
        super(Contexualizer, self).__init__()

        self.rnn = RNNModel(input_size, hidden_size, num_layers, True, dropout)
        self.output_size = hidden_size * 2

        self.reset_parameters()

    def reset_parameters(self):
        weights = self.rnn.rnn.all_weights
        for w1 in weights:
            for w2 in w1:
                if len(list(w2.size())) <= 1:
                    w2.data.fill_(0)
                else:
                    nn.init.xavier_normal_(w2.data, gain=1.414)

    def forward(self, s, mask):
        y = self.rnn(s, mask)[0]  # (batch_size, seq_len, 2 * hidden_size)

        return y


class ConcatAttention_Param(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(ConcatAttention_Param, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size))
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def forward(self, h, mask):
        '''
        h: (batch_size, len, input_size)
        mask: (batch_size, len)
        '''
        vq = self.vq.view(1, 1, -1).expand(h.size(0), h.size(1), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([h, vq], -1)))).squeeze(-1)  # (batch_size, len)

        s = s - ((mask == 0).float() * 10000)
        a = tc.softmax(s, dim=1)

        r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
        r = tc.sum(r, dim=1)  # (batch_size, input_size)

        return self.drop(r)


def get_2dmask(mask_hq, mask_hp, siz=None):
    if siz is None:
        siz = (mask_hq.size(0), mask_hq.size(1), mask_hp.size(1))

    mask_mat = 1
    if mask_hq is not None:
        mask_mat = mask_mat * mask_hq.unsqueeze(2).expand(siz)
    if mask_hp is not None:
        mask_mat = mask_mat * mask_hp.unsqueeze(1).expand(siz)
    return mask_mat


def Attention(hq, hp, mask_hq, mask_hp, my_method):
    standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
    mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

    hq_mat = hq.unsqueeze(2).expand(standard_size)
    hp_mat = hp.unsqueeze(1).expand(standard_size)

    s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)

    s = s - ((mask_mat == 0).float() * 10000)
    a = tc.softmax(s, dim=1)

    q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
    q = tc.sum(q, dim=1)  # (batch_size, len_p, input_size)

    return q


class ConcatAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(ConcatAttention, self).__init__()

        if input_size_2 < 0:
            input_size_2 = input_size
        self.ln = nn.Linear(input_size + input_size_2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = tc.cat([hq_mat, hp_mat], dim=-1)
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        '''
        hq: (batch_size, len_q, input_size)
        mask_hq: (batch_size, len_q)
        '''
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))


class MinusAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(MinusAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat - hp_mat
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))


class DotProductAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(DotProductAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat * hp_mat
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))


class BiLinearAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(BiLinearAttention, self).__init__()

        input_size_2 = input_size if input_size_2 < 0 else input_size_2

        self.ln = nn.Linear(input_size_2, input_size)
        self.drop = nn.Dropout(dropout)
        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq, hp, mask_p):
        # (bs, len, input_size)
        hp = self.ln(hp)
        hp = hp * mask_p.unsqueeze(-1)
        s = tc.matmul(hq, hp.transpose(-1, -2))

        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
        mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

        s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)

        s = s - ((mask_mat == 0).float() * 10000)
        a = tc.softmax(s, dim=1)

        hq_mat = hq.unsqueeze(2).expand(standard_size)
        q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
        q = tc.sum(q, dim=1)  # (batch_size, len_p, input_size)

        return self.drop(q)


class AggAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(AggAttention, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size, 1))
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.vq.data)
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)
        self.vq.data = self.vq.data[:, 0]

    def forward(self, hs, mask):
        '''
        hs: [(batch_size, len_q, input_size), ...]
        mask: (batch_size, len_q)
        '''
        hs = tc.cat([h.unsqueeze(0) for h in hs], dim=0)  # (4, batch_size, len_q, input_size)

        vq = self.vq.view(1, 1, 1, -1).expand(hs.size(0), hs.size(1), hs.size(2), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([hs, vq], -1)))).squeeze(-1)  # (4, batch_size, len_q)

        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
        a = tc.softmax(s, dim=0)

        x = a.unsqueeze(-1) * hs
        x = tc.sum(x, dim=0)  # (batch_size, len_q, input_size)

        return self.drop(x)


class Aggragator(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.3):
        super(Aggragator, self).__init__()

        now_size = input_size
        self.ln = nn.Linear(2 * input_size, 2 * input_size)

        now_size = 2 * input_size
        self.rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.rnn.output_size
        self.agg_att = AggAttention(now_size, now_size, dropout)

        now_size = self.agg_att.output_size
        self.agg_rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        self.drop = nn.Dropout(dropout)

        self.output_size = self.agg_rnn.output_size

    def forward(self, qs, hp, mask):
        '''
        qs: [(batch_size, len_p, input_size), ...]
        hp: (batch_size, len_p, input_size)
        mask is the same as hp's mask
        '''
        hs = [0 for _ in range(len(qs))]

        for i in range(len(qs)):
            q = qs[i]
            x = tc.cat([q, hp], dim=-1)
            g = tc.sigmoid(self.ln(x))
            x_star = x * g
            h = self.rnn(x_star, mask)

            hs[i] = h

        x = self.agg_att(hs, mask)  # (batch_size, len_p, output_size)
        h = self.agg_rnn(x, mask)  # (batch_size, len_p, output_size)
        return self.drop(h)


class Mwan_Imm(nn.Module):
    def __init__(self, input_size, hidden_size, num_class=3, dropout=0.2, use_allennlp=False):
        super(Mwan_Imm, self).__init__()

        now_size = input_size
        self.enc_s1 = Contexualizer(now_size, hidden_size, 2, dropout)
        self.enc_s2 = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.enc_s1.output_size
        self.att_c = ConcatAttention(now_size, hidden_size, dropout)
        self.att_b = BiLinearAttention(now_size, hidden_size, dropout)
        self.att_d = DotProductAttention(now_size, hidden_size, dropout)
        self.att_m = MinusAttention(now_size, hidden_size, dropout)

        now_size = self.att_c.output_size
        self.agg = Aggragator(now_size, hidden_size, dropout)

        now_size = self.enc_s1.output_size
        self.pred_1 = ConcatAttention_Param(now_size, hidden_size, dropout)
        now_size = self.agg.output_size
        self.pred_2 = ConcatAttention(now_size, hidden_size, dropout,
                                      input_size_2=self.pred_1.output_size)

        now_size = self.pred_2.output_size
        self.ln1 = nn.Linear(now_size, hidden_size)
        self.ln2 = nn.Linear(hidden_size, num_class)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln1.weight.data)
        nn.init.xavier_uniform_(self.ln2.weight.data)
        self.ln1.bias.data.fill_(0)
        self.ln2.bias.data.fill_(0)

    def forward(self, s1, s2, mas_s1, mas_s2):
        hq = self.enc_s1(s1, mas_s1)  # (batch_size, len_q, output_size)
        hp = self.enc_s1(s2, mas_s2)

        mas_s1 = mas_s1[:, :hq.size(1)]
        mas_s2 = mas_s2[:, :hp.size(1)]
        mas_q, mas_p = mas_s1, mas_s2

        qc = self.att_c(hq, hp, mas_s1, mas_s2)  # (batch_size, len_p, output_size)
        qb = self.att_b(hq, hp, mas_s1, mas_s2)
        qd = self.att_d(hq, hp, mas_s1, mas_s2)
        qm = self.att_m(hq, hp, mas_s1, mas_s2)

        ho = self.agg([qc, qb, qd, qm], hp, mas_s2)  # (batch_size, len_p, output_size)

        rq = self.pred_1(hq, mas_q)  # (batch_size, output_size)
        rp = self.pred_2(ho, rq.unsqueeze(1), mas_p)  # (batch_size, 1, output_size)
        rp = rp.squeeze(1)  # (batch_size, output_size)

        rp = F.relu(self.ln1(rp))
        rp = self.ln2(rp)

        return rp


class MwanModel(nn.Module):
    def __init__(self, num_class, EmbLayer, args_of_imm={}, ElmoLayer=None):
        super(MwanModel, self).__init__()

        self.emb = EmbLayer

        if ElmoLayer is not None:
            self.elmo = ElmoLayer
            self.elmo_preln = nn.Linear(3 * self.elmo.emb_size, self.elmo.emb_size)
            self.elmo_ln = nn.Linear(args_of_imm["input_size"] +
                                     self.elmo.emb_size, args_of_imm["input_size"])

        else:
            self.elmo = None

        self.imm = Mwan_Imm(num_class=num_class, **args_of_imm)
        self.drop = nn.Dropout(args_of_imm["dropout"])

    def forward(self, words1, words2, str_s1=None, str_s2=None, *pargs, **kwargs):
        '''
        str_s is for elmo use; however we don't use elmo here
        str_s: (batch_size, seq_len, word_len)
        '''
        s1, s2 = words1, words2

        mas_s1 = (s1 != 0).float()  # mas: (batch_size, seq_len)
        mas_s2 = (s2 != 0).float()  # mas: (batch_size, seq_len)

        mas_s1.requires_grad = False
        mas_s2.requires_grad = False

        s1_emb = self.emb(s1)
        s2_emb = self.emb(s2)

        if self.elmo is not None:
            s1_elmo = self.elmo(str_s1)
            s2_elmo = self.elmo(str_s2)

            s1_elmo = tc.tanh(self.elmo_preln(tc.cat(s1_elmo, dim=-1)))
            s2_elmo = tc.tanh(self.elmo_preln(tc.cat(s2_elmo, dim=-1)))

            s1_emb = tc.cat([s1_emb, s1_elmo], dim=-1)
            s2_emb = tc.cat([s2_emb, s2_elmo], dim=-1)

            s1_emb = tc.tanh(self.elmo_ln(s1_emb))
            s2_emb = tc.tanh(self.elmo_ln(s2_emb))

        s1_emb = self.drop(s1_emb)
        s2_emb = self.drop(s2_emb)

        y = self.imm(s1_emb, s2_emb, mas_s1, mas_s2)

        return {
            Const.OUTPUT: y,
        }
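
One pattern recurs in every attention block above: scores at padded positions are pushed down by 10000 before the softmax (s = s - ((mask == 0).float() * 10000)), so padding receives essentially zero attention weight. A small standalone sketch with toy numbers (not part of the model file):

    import torch as tc

    s = tc.tensor([[2.0, 1.0, 3.0, 0.5]])     # raw scores for one sequence of length 4
    mask = tc.tensor([[1.0, 1.0, 0.0, 0.0]])  # last two positions are padding

    s = s - ((mask == 0).float() * 10000)
    a = tc.softmax(s, dim=1)
    print(a)  # ~[[0.73, 0.27, 0.00, 0.00]]: padded positions get (near-)zero weight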