diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index e1964d99..c2a10210 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -9,7 +9,7 @@ from fastNLP.core.vocabulary import Vocabulary
 
 _READERS = {}
 
-class DataSet(list):
+class DataSet(object):
     """A DataSet object is a list of Instance objects.
 
     """
diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py
index 991927da..cd68d35d 100644
--- a/fastNLP/modules/decoder/CRF.py
+++ b/fastNLP/modules/decoder/CRF.py
@@ -31,7 +31,7 @@ class ConditionalRandomField(nn.Module):
         self.tag_size = tag_size
 
         # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
-        self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size))
+        self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
         if self.include_start_end_trans:
             self.start_scores = nn.Parameter(torch.randn(tag_size))
             self.end_scores = nn.Parameter(torch.randn(tag_size))
@@ -39,137 +39,121 @@ class ConditionalRandomField(nn.Module):
         # self.reset_parameter()
         initial_parameter(self, initial_method)
 
     def reset_parameter(self):
-        nn.init.xavier_normal_(self.transition_m)
+        nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
             nn.init.normal_(self.start_scores)
             nn.init.normal_(self.end_scores)
 
-    def _normalizer_likelihood(self, feats, masks):
+    def _normalizer_likelihood(self, logits, mask):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-        :param feats:FloatTensor, batch_size x max_len x tag_size
-        :param masks:ByteTensor, batch_size x max_len
+        :param logits:FloatTensor, max_len x batch_size x tag_size
+        :param mask:ByteTensor, max_len x batch_size
         :return:FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, batch_size x tag_size
+        seq_len, batch_size, n_tags = logits.size()
+        alpha = logits[0]
         if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
+            alpha = alpha + self.start_scores.view(1, -1)  # out-of-place add keeps logits[0] intact for _glod_score
 
-        # broadcast_trans_m, the meaning of entry in this matrix is [batch_idx, to_tag_id, from_tag_id]
-        broadcast_trans_m = self.transition_m.permute(
-            1, 0).unsqueeze(0).repeat(batch_size, 1, 1)
-        # loop
-        for i in range(1, max_len):
-            emit_score = feats[:, i].unsqueeze(2)
-            new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score
-
-            new_alpha = log_sum_exp(new_alpha, dim=2)
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-
-        if self.include_start_end_trans:
-            alpha = alpha + self.end_scores.view(1, -1)
-
-        return log_sum_exp(alpha)
-
-    def _glod_score(self, feats, tags, masks):
-        """
-        Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param tags: LongTensor, batch_size x max_len
-        :param masks: ByteTensor, batch_size x max_len
-        :return:FloatTensor, batch_size
-        """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, B x 1
-        if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \
-                feats[:, 0].gather(dim=1, index=tags[:, :1])
-        else:
-            alpha = feats[:, 0].gather(dim=1, index=tags[:, :1])
-
-        for i in range(1, max_len):
-            trans_score = self.transition_m[(
-                tags[:, i - 1], tags[:, i])].unsqueeze(1)
-            emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1])
-            new_alpha = alpha + trans_score + emit_score
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-
-        if self.include_start_end_trans:
-            last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1
-            last_from_tag_id = tags.gather(dim=1, index=last_tag_index)
-            trans_score = self.end_scores.view(
-                1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id)
-            alpha = alpha + trans_score
-
-        return alpha.squeeze(1)
-
-    def forward(self, feats, tags, masks):
-        """
-        Calculate the neg log likelihood
-        :param feats:FloatTensor, batch_size x max_len x tag_size
-        :param tags:LongTensor, batch_size x max_len
-        :param masks:ByteTensor batch_size x max_len
-        :return:FloatTensor, batch_size
-        """
-        all_path_score = self._normalizer_likelihood(feats, masks)
-        gold_path_score = self._glod_score(feats, tags, masks)
-
-        return all_path_score - gold_path_score
-
-    def viterbi_decode(self, feats, masks, get_score=False):
-        """
-        Given a feats matrix, return best decode path and best score.
-        :param feats:
-        :param masks:
-        :param get_score: bool, whether to output the decode score.
-        :return:List[Tuple(List, float)],
-        """
-        batch_size, max_len, tag_size = feats.size()
-
-        paths = torch.zeros(batch_size, max_len - 1, self.tag_size)
-        if self.include_start_end_trans:
-            alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-        for i in range(1, max_len):
-            new_alpha = alpha.clone()
-            for t in range(self.tag_size):
-                pre_scores = self.transition_m[:, t].view(
-                    1, self.tag_size) + alpha
-                max_score, indices = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_score + feats[:, i, t]
-                paths[:, i - 1, t] = indices
-            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
+        for i in range(1, seq_len):
+            emit_score = logits[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags)
+            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
+            alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)
 
         if self.include_start_end_trans:
             alpha += self.end_scores.view(1, -1)
 
-        max_scores, indices = alpha.max(dim=1)
-        indices = indices.cpu().numpy()
-        final_paths = []
-        paths = paths.cpu().numpy().astype(int)
+        return log_sum_exp(alpha, 1)
 
-        seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
+    def _glod_score(self, logits, tags, mask):
+        """
+        Compute the score for the gold path.
+        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param tags: LongTensor, max_len x batch_size
+        :param mask: ByteTensor, max_len x batch_size
+        :return:FloatTensor, batch_size
+        """
+        seq_len, batch_size, _ = logits.size()
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
+
+        # trans_score [L-1, B]
+        trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
+        # emit_score [L, B]
+        emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
+        # score [L-1, B]
+        score = trans_score + emit_score[:seq_len-1, :]
+        score = score.sum(0) + emit_score[-1]
+        if self.include_start_end_trans:
+            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
+            last_idx = mask.long().sum(0) - 1  # index of each sequence's last valid position
+            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
+            score += st_scores + ed_scores
+        # return [B,]
+        return score
+
+    def forward(self, feats, tags, mask):
+        """
+        Calculate the negative log likelihood.
+        :param feats:FloatTensor, batch_size x max_len x tag_size
+        :param tags:LongTensor, batch_size x max_len
+        :param mask:ByteTensor batch_size x max_len
+        :return:FloatTensor, batch_size
+        """
+        feats = feats.transpose(0, 1)
+        tags = tags.transpose(0, 1)
+        mask = mask.transpose(0, 1).float()  # float mask so the masked updates below stay in one dtype
+        all_path_score = self._normalizer_likelihood(feats, mask)
+        gold_path_score = self._glod_score(feats, tags, mask)
+
+        return all_path_score - gold_path_score
+
+    def viterbi_decode(self, data, mask, get_score=False):
+        """
+        Given a feats matrix, return best decode path and best score.
+        :param data:FloatTensor, batch_size x max_len x tag_size
+        :param mask:ByteTensor batch_size x max_len
+        :param get_score: bool, whether to output the decode score.
+        :return: scores, paths
+        """
+        batch_size, seq_len, n_tags = data.size()
+        data = data.transpose(0, 1).data  # L, B, H
+        mask = mask.transpose(0, 1).data.float()  # L, B
+
+        # dp
+        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
+        vscore = data[0]
+        if self.include_start_end_trans:
+            vscore = vscore + self.start_scores.view(1, -1)  # out-of-place so data[0] is not modified
+        for i in range(1, seq_len):
+            prev_score = vscore.view(batch_size, n_tags, 1)
+            cur_score = data[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags).data
+            score = prev_score + trans_score + cur_score
+            best_score, best_dst = score.max(1)
+            vpath[i] = best_dst
+            vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
+
+        if self.include_start_end_trans:
+            vscore += self.end_scores.view(1, -1)
+
+        # backtrace
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
+        lens = (mask.long().sum(0) - 1)
+        # idxes [L, B], batched idx from seq_len-1 to 0
+        idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
+
+        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
+        ans_score, last_tags = vscore.max(1)
+        ans[idxes[0], batch_idx] = last_tags
+        for i in range(seq_len - 1):
+            last_tags = vpath[idxes[i], batch_idx, last_tags]
+            ans[idxes[i+1], batch_idx] = last_tags
 
-        for b in range(batch_size):
-            path = [indices[b]]
-            for i in range(seq_lens[b] - 2, -1, -1):
-                index = paths[b, i, path[-1]]
-                path.append(index)
-            final_paths.append(path[::-1])
         if get_score:
-            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
-        else:
-            return final_paths
+            return ans_score, ans.transpose(0, 1)
+        return ans.transpose(0, 1)
\ No newline at end of file
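
Usage sketch (not part of the patch): after this change the public API stays batch-first while the internals run time-first. A minimal example of driving the refactored module, assuming a constructor of the form ConditionalRandomField(tag_size, include_start_end_trans=False); only the attributes it sets, not its full signature, are visible in this diff:

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField

    batch_size, max_len, tag_size = 2, 5, 4
    crf = ConditionalRandomField(tag_size, include_start_end_trans=True)

    feats = torch.randn(batch_size, max_len, tag_size)         # emission scores
    tags = torch.randint(tag_size, (batch_size, max_len))      # gold tag ids (LongTensor)
    mask = torch.ones(batch_size, max_len, dtype=torch.uint8)  # 1 = token, 0 = padding
    mask[1, 3:] = 0                                            # second sequence has length 3

    nll = crf(feats, tags, mask)   # (batch_size,) negative log-likelihood
    loss = nll.mean()              # reduce to a scalar training loss

    scores, paths = crf.viterbi_decode(feats, mask, get_score=True)
    # paths: (batch_size, max_len) LongTensor of best tag ids; entries at
    # padded positions are not meaningful

forward returns one value per sequence, so the caller chooses the reduction; viterbi_decode consumes the same batch-first tensors and transposes internally, mirroring forward.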
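Both the old and the new forward algorithm lean on a log_sum_exp helper that is imported in CRF.py outside these hunks. A minimal sketch of its assumed semantics, a numerically stable logsumexp reduction over dim (the real implementation is not shown in this diff):

    import torch

    def log_sum_exp(x, dim=-1):
        # subtract the per-slice max before exponentiating so exp() cannot overflow
        max_value, _ = x.max(dim=dim, keepdim=True)
        return max_value.squeeze(dim) + (x - max_value).exp().sum(dim=dim).log()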