mirror of https://gitee.com/fastnlp/fastNLP.git

commit 8fa50d1749 (parent 9b25de3ff3)

    update crf
@@ -9,7 +9,7 @@ from fastNLP.core.vocabulary import Vocabulary
 _READERS = {}
 
 
-class DataSet(list):
+class DataSet(object):
     """A DataSet object is a list of Instance objects.
 
     """
@@ -31,7 +31,7 @@ class ConditionalRandomField(nn.Module):
         self.tag_size = tag_size
 
         # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
-        self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size))
+        self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
         if self.include_start_end_trans:
             self.start_scores = nn.Parameter(torch.randn(tag_size))
             self.end_scores = nn.Parameter(torch.randn(tag_size))
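Note: the renamed trans_m keeps the (from_tag_id, to_tag_id) orientation described in the comment above. A minimal sketch, not from this commit (the 3-tag setup is made up), of reading one transition score:

    import torch

    # trans_m[from_tag_id, to_tag_id], as the comment in the hunk above states.
    trans_m = torch.randn(3, 3)    # stand-in for the nn.Parameter
    score_0_to_2 = trans_m[0, 2]   # score of moving from tag 0 to tag 2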
@@ -39,137 +39,121 @@ class ConditionalRandomField(nn.Module):
         # self.reset_parameter()
         initial_parameter(self, initial_method)
 
     def reset_parameter(self):
-        nn.init.xavier_normal_(self.transition_m)
+        nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
             nn.init.normal_(self.start_scores)
             nn.init.normal_(self.end_scores)
 
-    def _normalizer_likelihood(self, feats, masks):
+    def _normalizer_likelihood(self, logits, mask):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-        :param feats:FloatTensor, batch_size x max_len x tag_size
-        :param masks:ByteTensor, batch_size x max_len
+        :param logits:FloatTensor, max_len x batch_size x tag_size
+        :param mask:ByteTensor, max_len x batch_size
         :return:FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, batch_size x tag_size
+        seq_len, batch_size, n_tags = logits.size()
+        alpha = logits[0]
         if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-
-        # broadcast_trans_m, the meaning of entry in this matrix is [batch_idx, to_tag_id, from_tag_id]
-        broadcast_trans_m = self.transition_m.permute(
-            1, 0).unsqueeze(0).repeat(batch_size, 1, 1)
-        # loop
-        for i in range(1, max_len):
-            emit_score = feats[:, i].unsqueeze(2)
-            new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score
-
-            new_alpha = log_sum_exp(new_alpha, dim=2)
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
+            alpha += self.start_scores.view(1, -1)
+
+        for i in range(1, seq_len):
+            emit_score = logits[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags)
+            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
+            alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)
 
         if self.include_start_end_trans:
-            alpha = alpha + self.end_scores.view(1, -1)
+            alpha += self.end_scores.view(1, -1)
 
-        return log_sum_exp(alpha)
+        return log_sum_exp(alpha, 1)
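Note: the new _normalizer_likelihood computes the log partition function with the forward algorithm in log space, replacing the explicit broadcast_trans_m buffer. A minimal standalone sketch of the same recurrence (my own simplification: no mask or start/end scores, and torch.logsumexp stands in for fastNLP's log_sum_exp helper):

    import torch

    def forward_log_partition(logits, trans_m):
        # logits: [seq_len, batch, n_tags]; trans_m: [n_tags, n_tags], (from, to).
        seq_len, batch, n_tags = logits.size()
        alpha = logits[0]                               # [batch, n_tags]
        for i in range(1, seq_len):
            emit = logits[i].view(batch, 1, n_tags)     # broadcast over "from"
            trans = trans_m.view(1, n_tags, n_tags)     # broadcast over batch
            tmp = alpha.view(batch, n_tags, 1) + trans + emit
            alpha = torch.logsumexp(tmp, dim=1)         # sum out the "from" tag
        return torch.logsumexp(alpha, dim=1)            # [batch], i.e. log Z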
-    def _glod_score(self, feats, tags, masks):
+    def _glod_score(self, logits, tags, mask):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param tags: LongTensor, batch_size x max_len
-        :param masks: ByteTensor, batch_size x max_len
+        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param tags: LongTensor, max_len x batch_size
+        :param mask: ByteTensor, max_len x batch_size
         :return:FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, B x 1
-        if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \
-                feats[:, 0].gather(dim=1, index=tags[:, :1])
-        else:
-            alpha = feats[:, 0].gather(dim=1, index=tags[:, :1])
-
-        for i in range(1, max_len):
-            trans_score = self.transition_m[(
-                tags[:, i - 1], tags[:, i])].unsqueeze(1)
-            emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1])
-            new_alpha = alpha + trans_score + emit_score
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-
-        if self.include_start_end_trans:
-            last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1
-            last_from_tag_id = tags.gather(dim=1, index=last_tag_index)
-            trans_score = self.end_scores.view(
-                1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id)
-            alpha = alpha + trans_score
-
-        return alpha.squeeze(1)
+        seq_len, batch_size, _ = logits.size()
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
+
+        # trans_score [L-1, B]
+        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
+        # emit_score [L, B]
+        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
+        # score [L-1, B]
+        score = trans_score + emit_score[:seq_len - 1, :]
+        score = score.sum(0) + emit_score[-1]
+        if self.include_start_end_trans:
+            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
+            last_idx = mask.long().sum(0) - 1
+            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
+            score += st_scores + ed_scores
+        # return [B,]
+        return score
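Note: the rewritten _glod_score ("glod" is the method's actual name in this commit) replaces the per-step gather loop with one advanced-indexing lookup over the whole gold path. A tiny sketch of that trick in isolation (toy tensors, names are mine):

    import torch

    trans_m = torch.randn(3, 3)                  # [from_tag, to_tag]
    tags = torch.tensor([2, 0, 1, 1])            # a gold path of length 4

    # One-shot lookup of the scores for (2->0), (0->1), (1->1).
    trans_scores = trans_m[tags[:-1], tags[1:]]  # shape [3]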
-    def forward(self, feats, tags, masks):
+    def forward(self, feats, tags, mask):
         """
         Calculate the neg log likelihood
         :param feats:FloatTensor, batch_size x max_len x tag_size
         :param tags:LongTensor, batch_size x max_len
-        :param masks:ByteTensor batch_size x max_len
+        :param mask:ByteTensor batch_size x max_len
         :return:FloatTensor, batch_size
         """
-        all_path_score = self._normalizer_likelihood(feats, masks)
-        gold_path_score = self._glod_score(feats, tags, masks)
+        feats = feats.transpose(0, 1)
+        tags = tags.transpose(0, 1)
+        mask = mask.transpose(0, 1)
+        all_path_score = self._normalizer_likelihood(feats, mask)
+        gold_path_score = self._glod_score(feats, tags, mask)
 
         return all_path_score - gold_path_score
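Note: forward still accepts batch-first inputs and now transposes them to length-first for the two helpers, returning a per-sample negative log likelihood of shape [batch_size]. A hedged usage sketch (the constructor call and the loss reduction are assumptions, not shown in this diff):

    import torch

    batch_size, max_len, tag_size = 2, 5, 4
    feats = torch.randn(batch_size, max_len, tag_size)    # encoder emissions
    tags = torch.randint(tag_size, (batch_size, max_len))
    mask = torch.ones(batch_size, max_len, dtype=torch.uint8)

    # crf = ConditionalRandomField(tag_size)   # constructor args assumed
    # loss = crf(feats, tags, mask).mean()     # reduce the [batch_size] NLL
    # loss.backward()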
-    def viterbi_decode(self, feats, masks, get_score=False):
+    def viterbi_decode(self, data, mask, get_score=False):
         """
         Given a feats matrix, return best decode path and best score.
-        :param feats:
-        :param masks:
+        :param data:FloatTensor, batch_size x max_len x tag_size
+        :param mask:ByteTensor batch_size x max_len
         :param get_score: bool, whether to output the decode score.
-        :return:List[Tuple(List, float)],
+        :return: scores, paths
         """
-        batch_size, max_len, tag_size = feats.size()
-
-        paths = torch.zeros(batch_size, max_len - 1, self.tag_size)
-        if self.include_start_end_trans:
-            alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-        for i in range(1, max_len):
-            new_alpha = alpha.clone()
-            for t in range(self.tag_size):
-                pre_scores = self.transition_m[:, t].view(
-                    1, self.tag_size) + alpha
-                max_score, indices = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_score + feats[:, i, t]
-                paths[:, i - 1, t] = indices
-            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
-
-        max_scores, indices = alpha.max(dim=1)
-        indices = indices.cpu().numpy()
-        final_paths = []
-        paths = paths.cpu().numpy().astype(int)
-
-        seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
-
-        for b in range(batch_size):
-            path = [indices[b]]
-            for i in range(seq_lens[b] - 2, -1, -1):
-                index = paths[b, i, path[-1]]
-                path.append(index)
-            final_paths.append(path[::-1])
-        if get_score:
-            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
-        else:
-            return final_paths
+        batch_size, seq_len, n_tags = data.size()
+        data = data.transpose(0, 1).data  # L, B, H
+        mask = mask.transpose(0, 1).data.float()  # L, B
+
+        # dp
+        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
+        vscore = data[0]
+        if self.include_start_end_trans:
+            vscore += self.start_scores.view(1, -1)
+        for i in range(1, seq_len):
+            prev_score = vscore.view(batch_size, n_tags, 1)
+            cur_score = data[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags).data
+            score = prev_score + trans_score + cur_score
+            best_score, best_dst = score.max(1)
+            vpath[i] = best_dst
+            vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
+
+        if self.include_start_end_trans:
+            vscore += self.end_scores.view(1, -1)
+
+        # backtrace
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
+        lens = (mask.long().sum(0) - 1)
+        # idxes [L, B], batched idx from seq_len-1 to 0
+        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
+
+        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
+        ans_score, last_tags = vscore.max(1)
+        ans[idxes[0], batch_idx] = last_tags
+        for i in range(seq_len - 1):
+            last_tags = vpath[idxes[i], batch_idx, last_tags]
+            ans[idxes[i + 1], batch_idx] = last_tags
+        if get_score:
+            return ans_score, ans.transpose(0, 1)
+        return ans.transpose(0, 1)
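Note: the rewritten viterbi_decode collapses the per-tag Python loop into one max over a [batch, from_tag, to_tag] score tensor, and the modulo trick in idxes lets every sequence backtrace from its own last valid position regardless of padding. A small sketch of just that indexing trick (toy lengths are mine):

    import torch

    seq_len = 5
    lens = torch.tensor([4, 2]) - 1    # last valid position per sequence
    seq_idx = torch.arange(seq_len)
    idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
    print(idxes[:, 0].tolist())        # [3, 2, 1, 0, 4]
    print(idxes[:, 1].tolist())        # [1, 0, 4, 3, 2]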