update crf

yunfan 2018-11-09 14:07:17 +08:00
parent 9b25de3ff3
commit 8fa50d1749
2 changed files with 102 additions and 118 deletions


@@ -9,7 +9,7 @@ from fastNLP.core.vocabulary import Vocabulary
 _READERS = {}

-class DataSet(list):
+class DataSet(object):
     """A DataSet object is a list of Instance objects.
     """


@@ -31,7 +31,7 @@ class ConditionalRandomField(nn.Module):
         self.tag_size = tag_size

         # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
-        self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size))
+        self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
         if self.include_start_end_trans:
             self.start_scores = nn.Parameter(torch.randn(tag_size))
             self.end_scores = nn.Parameter(torch.randn(tag_size))
@@ -39,137 +39,121 @@ class ConditionalRandomField(nn.Module):
         # self.reset_parameter()
         initial_parameter(self, initial_method)

     def reset_parameter(self):
-        nn.init.xavier_normal_(self.transition_m)
+        nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
             nn.init.normal_(self.start_scores)
             nn.init.normal_(self.end_scores)
-    def _normalizer_likelihood(self, feats, masks):
+    def _normalizer_likelihood(self, logits, mask):
         """
         Computes the (batch_size,) denominator term for the log-likelihood: the
         log-sum-exp of the scores of all possible tag sequences (the log partition function).
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param masks: ByteTensor, batch_size x max_len
+        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param mask: FloatTensor (0/1), max_len x batch_size
         :return: FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-        # alpha, batch_size x tag_size
+        seq_len, batch_size, n_tags = logits.size()
+        alpha = logits[0]
         if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
+            alpha = alpha + self.start_scores.view(1, -1)  # not in-place, so logits[0] is left untouched
-        # broadcast_trans_m, the meaning of entry in this matrix is [batch_idx, to_tag_id, from_tag_id]
-        broadcast_trans_m = self.transition_m.permute(
-            1, 0).unsqueeze(0).repeat(batch_size, 1, 1)
-        # loop
-        for i in range(1, max_len):
-            emit_score = feats[:, i].unsqueeze(2)
-            new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score
-            new_alpha = log_sum_exp(new_alpha, dim=2)
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-        if self.include_start_end_trans:
-            alpha = alpha + self.end_scores.view(1, -1)
-        return log_sum_exp(alpha)
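Both the old and the new `_normalizer_likelihood` rely on a `log_sum_exp` helper that this diff does not show (it is defined elsewhere in the module). A minimal numerically stable sketch, assuming the `(tensor, dim)` signature used above:

import torch

def log_sum_exp(x, dim=-1):
    """Numerically stable log(sum(exp(x))) along `dim`.

    The max is subtracted before exponentiating to avoid overflow,
    then added back outside the log.
    """
    max_x, _ = x.max(dim=dim, keepdim=True)
    return max_x.squeeze(dim) + (x - max_x).exp().sum(dim=dim).log()

# sanity check against the naive formula on values where it does not overflow
x = torch.randn(2, 5)
assert torch.allclose(log_sum_exp(x, 1), x.exp().sum(1).log(), atol=1e-5)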
-    def _glod_score(self, feats, tags, masks):
-        """
-        Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param tags: LongTensor, batch_size x max_len
-        :param masks: ByteTensor, batch_size x max_len
-        :return: FloatTensor, batch_size
-        """
-        batch_size, max_len, _ = feats.size()
-        # alpha, B x 1
-        if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \
-                feats[:, 0].gather(dim=1, index=tags[:, :1])
-        else:
-            alpha = feats[:, 0].gather(dim=1, index=tags[:, :1])
-        for i in range(1, max_len):
-            trans_score = self.transition_m[(
-                tags[:, i - 1], tags[:, i])].unsqueeze(1)
-            emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1])
-            new_alpha = alpha + trans_score + emit_score
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-        if self.include_start_end_trans:
-            last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1
-            last_from_tag_id = tags.gather(dim=1, index=last_tag_index)
-            trans_score = self.end_scores.view(
-                1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id)
-            alpha = alpha + trans_score
-        return alpha.squeeze(1)

-    def forward(self, feats, tags, masks):
-        """
-        Calculate the negative log likelihood.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param tags: LongTensor, batch_size x max_len
-        :param masks: ByteTensor, batch_size x max_len
-        :return: FloatTensor, batch_size
-        """
-        all_path_score = self._normalizer_likelihood(feats, masks)
-        gold_path_score = self._glod_score(feats, tags, masks)
-        return all_path_score - gold_path_score

-    def viterbi_decode(self, feats, masks, get_score=False):
-        """
-        Given a feats matrix, return the best decode path and best score.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param masks: ByteTensor, batch_size x max_len
-        :param get_score: bool, whether to output the decode score.
-        :return: List[Tuple(List, float)]
-        """
-        batch_size, max_len, tag_size = feats.size()
-        paths = torch.zeros(batch_size, max_len - 1, self.tag_size)
-        if self.include_start_end_trans:
-            alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-        for i in range(1, max_len):
-            new_alpha = alpha.clone()
-            for t in range(self.tag_size):
-                pre_scores = self.transition_m[:, t].view(
-                    1, self.tag_size) + alpha
-                max_score, indices = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_score + feats[:, i, t]
-                paths[:, i - 1, t] = indices
-            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
+        for i in range(1, seq_len):
+            emit_score = logits[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags)
+            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
+            alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)
         if self.include_start_end_trans:
             alpha += self.end_scores.view(1, -1)
-        max_scores, indices = alpha.max(dim=1)
-        indices = indices.cpu().numpy()
-        final_paths = []
-        paths = paths.cpu().numpy().astype(int)
-        seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
+        return log_sum_exp(alpha, 1)
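The rewritten recurrence works on a time-first `(max_len, batch_size, tag_size)` layout and replaces the old `permute`/`repeat` of the transition matrix with pure broadcasting: `tmp[b, f, t] = alpha[b, f] + trans[f, t] + emit[b, t]`, after which the "from" dimension is summed out in log space. A self-contained check of one step against an explicit loop, using `torch.logsumexp` in place of the module's helper:

import torch

B, T = 2, 4                   # batch size, number of tags
alpha = torch.randn(B, T)     # running forward scores, one per previous tag
emit = torch.randn(B, T)      # emission scores at the current step
trans = torch.randn(T, T)     # trans[from_tag, to_tag]

# one vectorized step, as in the new code
tmp = alpha.view(B, T, 1) + emit.view(B, 1, T) + trans.view(1, T, T)
new_alpha = torch.logsumexp(tmp, dim=1)  # sum out the "from" tag

# the same step as an explicit loop over tag pairs
ref = torch.empty(B, T)
for b in range(B):
    for t in range(T):
        scores = torch.stack([alpha[b, f] + trans[f, t] + emit[b, t] for f in range(T)])
        ref[b, t] = torch.logsumexp(scores, 0)

assert torch.allclose(new_alpha, ref, atol=1e-5)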
+    def _glod_score(self, logits, tags, mask):
+        """
+        Compute the score for the gold path.
+        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param tags: LongTensor, max_len x batch_size
+        :param mask: FloatTensor (0/1), max_len x batch_size
+        :return: FloatTensor, batch_size
+        """
+        seq_len, batch_size, _ = logits.size()
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
+        # trans_score [L-1, B]
+        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
+        # emit_score [L, B]
+        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
+        # score [L-1, B]
+        score = trans_score + emit_score[:seq_len - 1, :]
+        score = score.sum(0) + emit_score[-1]
+        if self.include_start_end_trans:
+            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
+            last_idx = mask.long().sum(0) - 1  # index of the last real token, hence the -1
+            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
+            score += st_scores + ed_scores
+        # return [B,]
+        return score
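`_glod_score` now scores the whole gold path without a Python loop: `self.trans_m[tags[:seq_len-1], tags[1:]]` pairs every tag with its successor through advanced indexing. A small demo of that lookup on made-up tags (time-first layout, as inside this class):

import torch

T = 3                                                        # number of tags
trans_m = torch.arange(T * T, dtype=torch.float).view(T, T)  # trans_m[from_tag, to_tag]
tags = torch.tensor([[0, 2],                                 # max_len x batch_size
                     [1, 2],
                     [2, 0]])

# trans_score[i, b] = trans_m[tags[i, b], tags[i + 1, b]]
trans_score = trans_m[tags[:-1], tags[1:]]
print(trans_score.shape)                 # (max_len - 1, batch_size)

# spelled out for sequence 0, which walks tags 0 -> 1 -> 2
assert trans_score[0, 0] == trans_m[0, 1]
assert trans_score[1, 0] == trans_m[1, 2]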
+    def forward(self, feats, tags, mask):
+        """
+        Calculate the negative log likelihood.
+        :param feats: FloatTensor, batch_size x max_len x tag_size
+        :param tags: LongTensor, batch_size x max_len
+        :param mask: ByteTensor, batch_size x max_len
+        :return: FloatTensor, batch_size
+        """
+        feats = feats.transpose(0, 1)
+        tags = tags.transpose(0, 1)
+        mask = mask.transpose(0, 1).float()  # cast once so the mask can be used in float arithmetic below
+        all_path_score = self._normalizer_likelihood(feats, mask)
+        gold_path_score = self._glod_score(feats, tags, mask)
+        return all_path_score - gold_path_score
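Since `forward` transposes to the time-first layout internally, callers still pass batch-first tensors. A hedged usage sketch; the constructor arguments are inferred from the attributes visible in this diff (`tag_size`, `include_start_end_trans`) and may not match the real signature exactly:

import torch

batch_size, max_len, tag_size = 2, 5, 4
crf = ConditionalRandomField(tag_size, include_start_end_trans=True)  # assumed signature

feats = torch.randn(batch_size, max_len, tag_size)          # emission scores
tags = torch.randint(tag_size, (batch_size, max_len))       # gold tag ids
mask = torch.ones(batch_size, max_len, dtype=torch.uint8)   # 1 = real token, 0 = padding
mask[1, 3:] = 0                                             # second sequence has length 3

loss = crf(feats, tags, mask)  # per-sequence NLL: log Z - score(gold path)
print(loss.shape)              # torch.Size([2])
loss.mean().backward()         # reduce before backprop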
+    def viterbi_decode(self, data, mask, get_score=False):
+        """
+        Given a feats matrix, return the best decode path and best score.
+        :param data: FloatTensor, batch_size x max_len x tag_size
+        :param mask: ByteTensor, batch_size x max_len
+        :param get_score: bool, whether to output the decode score.
+        :return: scores, paths
+        """
+        batch_size, seq_len, n_tags = data.size()
+        data = data.transpose(0, 1).data  # L, B, H
+        mask = mask.transpose(0, 1).data.float()  # L, B
+        # dp
+        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
+        vscore = data[0]
+        if self.include_start_end_trans:
+            vscore = vscore + self.start_scores.view(1, -1)  # not in-place, so data[0] is left untouched
+        for i in range(1, seq_len):
+            prev_score = vscore.view(batch_size, n_tags, 1)
+            cur_score = data[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags).data
+            score = prev_score + trans_score + cur_score
+            best_score, best_dst = score.max(1)
+            vpath[i] = best_dst
+            vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
+        if self.include_start_end_trans:
+            vscore += self.end_scores.view(1, -1)
+        # backtrace
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
+        lens = (mask.long().sum(0) - 1)
+        # idxes [L, B], batched idx from seq_len-1 to 0
+        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
+        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
+        ans_score, last_tags = vscore.max(1)
+        ans[idxes[0], batch_idx] = last_tags
+        for i in range(seq_len - 1):
+            last_tags = vpath[idxes[i], batch_idx, last_tags]
+            ans[idxes[i + 1], batch_idx] = last_tags
-        for b in range(batch_size):
-            path = [indices[b]]
-            for i in range(seq_lens[b] - 2, -1, -1):
-                index = paths[b, i, path[-1]]
-                path.append(index)
-            final_paths.append(path[::-1])
         if get_score:
-            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
-        else:
-            return final_paths
+            return ans_score, ans.transpose(0, 1)
+        return ans.transpose(0, 1)
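The backtrace avoids per-sequence Python loops by precomputing, for every batch element, the sequence of positions from its last real token back to 0: `idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len`. A standalone illustration of the index trick:

import torch

seq_len = 5
lens = torch.tensor([4, 2])   # index of the last real token per sequence (length - 1)
seq_idx = torch.arange(seq_len)

# row i holds, for each sequence, the position visited at backtrace step i
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
print(idxes)
# tensor([[4, 2],
#         [3, 1],
#         [2, 0],
#         [1, 4],
#         [0, 3]])
# Column 0 walks 4 -> 3 -> 2 -> 1 -> 0; column 1 walks 2 -> 1 -> 0 and then
# wraps around, but the wrapped steps only write to padded positions, which
# the caller ignores.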