Merge remote-tracking branch 'origin/dataset' into dataset

FengZiYjun 2018-11-09 19:53:08 +08:00
commit 12e9a93b52
8 changed files with 518 additions and 1 deletion

fastNLP/api/api.py

@@ -0,0 +1,11 @@
class API:
def __init__(self):
pass
def predict(self):
pass
def load(self):
pass
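
The three stubs above sketch the intended high-level inference interface (load a trained model, then predict). A minimal, hypothetical illustration of how a concrete subclass might fill them in; the subclass name and its behaviour are assumptions, not part of this commit:

class WhitespaceSegAPI(API):
    """Toy example: 'load' just records a path and 'predict' splits on whitespace."""
    def load(self, path=None):
        self.model_path = path  # a real implementation would restore trained model weights here

    def predict(self, sentences=None):
        sentences = sentences or []
        return [s.split() for s in sentences]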


@@ -8,7 +8,6 @@ class Pipeline:
def add_processor(self, processor):
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
processor_name = type(processor)
self.pipeline.append(processor)
def process(self, dataset):
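
Only a fragment of Pipeline is shown in this diff. A hedged sketch of how it is presumably meant to be used, assuming process applies the registered processors to the DataSet in order; the field names are made up, and the two processors are defined later in this commit:

pipe = Pipeline()
pipe.add_processor(FullSpaceToHalfSpaceProcessor('raw_sentence'))
pipe.add_processor(CWSCharSegProcessor('raw_sentence', 'chars_list'))
dataset = pipe.process(dataset)  # each processor transforms the DataSet in turn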


@@ -0,0 +1,135 @@
from torch import nn
import torch
import torch.nn.functional as F
from fastNLP.modules.decoder.MLP import MLP
from fastNLP.models.base_model import BaseModel
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
class CWSBiLSTMEncoder(BaseModel):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
super().__init__()
self.input_size = 0
self.num_bigram_per_char = num_bigram_per_char
self.bidirectional = bidirectional
self.num_layers = num_layers
self.embed_drop_p = embed_drop_p
if self.bidirectional:
self.hidden_size = hidden_size//2
self.num_directions = 2
else:
self.hidden_size = hidden_size
self.num_directions = 1
        if bigram_vocab_num is not None:
            assert num_bigram_per_char is not None, "Specify num_bigram_per_char."
if vocab_num is not None:
self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim)
self.input_size += embed_dim
if bigram_vocab_num is not None:
self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim)
self.input_size += self.num_bigram_per_char*bigram_embed_dim
        # NOTE: num_criterion is never assigned in this class; guard the lookup so the optional
        # criterion embeddings are only created when it has been set externally.
        if getattr(self, 'num_criterion', None) is not None:
            if bidirectional:
                self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                                 embedding_dim=self.hidden_size)
            self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                            embedding_dim=self.hidden_size)
        if self.embed_drop_p is not None:
self.embedding_drop = nn.Dropout(p=self.embed_drop_p)
self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional,
batch_first=True, num_layers=self.num_layers)
self.reset_parameters()
def reset_parameters(self):
for name, param in self.named_parameters():
if 'bias_hh' in name:
nn.init.constant_(param, 0)
elif 'bias_ih' in name:
nn.init.constant_(param, 1)
else:
nn.init.xavier_uniform_(param)
def init_embedding(self, embedding, embed_name):
if embed_name == 'bigram':
self.bigram_embedding.weight.data = torch.from_numpy(embedding)
elif embed_name == 'char':
self.char_embedding.weight.data = torch.from_numpy(embedding)
def forward(self, chars, bigrams=None, seq_lens=None):
batch_size, max_len = chars.size()
x_tensor = self.char_embedding(chars)
        if bigrams is not None:
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
outputs, _ = self.lstm(packed_x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
_, desorted_indices = torch.sort(sorted_indices, descending=False)
outputs = outputs[desorted_indices]
return outputs
class CWSBiLSTMSegApp(BaseModel):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2):
super(CWSBiLSTMSegApp, self).__init__()
self.tag_size = tag_size
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
hidden_size, bidirectional, embed_drop_p, num_layers)
size_layer = [hidden_size, 100, tag_size]
self.decoder_model = MLP(size_layer)
def forward(self, **kwargs):
chars = kwargs['chars']
        # the key checked and the key read must both be 'bigrams'
        bigrams = kwargs.get('bigrams', None)
seq_lens = kwargs['seq_lens']
feats = self.encoder_model(chars, bigrams, seq_lens)
probs = self.decoder_model(feats)
pred_dict = {}
pred_dict['seq_lens'] = seq_lens
pred_dict['pred_prob'] = probs
return pred_dict
def loss_fn(self, pred_dict, true_dict):
seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()
pred_prob = pred_dict['pred_prob']
true_y = true_dict['tags']
        # TODO: the loss is currently hard-coded here
        # masked token-level cross entropy, reduced to a scalar by averaging over non-padding positions
        loss = (F.cross_entropy(pred_prob.view(-1, self.tag_size),
                                true_y.view(-1), reduction='none') * masks.view(-1)).sum() / torch.sum(masks)
        return loss
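
A hedged usage sketch of the model above with dummy tensors; the vocabulary size and batch shape are made up, and CWSBiLSTMSegApp is assumed to be importable from this module:

import torch

model = CWSBiLSTMSegApp(vocab_num=3000, embed_dim=100, hidden_size=200, tag_size=2)
chars = torch.randint(0, 3000, (4, 20))   # 4 sentences, padded to 20 characters
seq_lens = torch.tensor([20, 18, 15, 9])  # true lengths; sorting is handled inside the encoder
pred = model(chars=chars, seq_lens=seq_lens)
print(pred['pred_prob'].shape)            # expected: torch.Size([4, 20, 2]), one seg/app score pair per character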


@@ -0,0 +1,283 @@
import re
# TextField and TokenListFiled are used by the processors below; they are assumed to live in
# fastNLP.core.field alongside SeqLabelField.
from fastNLP.core.field import SeqLabelField, TextField, TokenListFiled
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import Processor
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'
class FullSpaceToHalfSpaceProcessor(Processor):
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
change_space=True):
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
self.change_alpha = change_alpha
self.change_digit = change_digit
self.change_punctuation = change_punctuation
self.change_space = change_space
        FH_SPACE = [(u"　", u" ")]
        FH_NUM = [
            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
        FH_ALPHA = [
            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
            (u"ｚ", u"z"),
            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
            (u"Ｚ", u"Z")]
        # Use punctuation conversion with caution: e.g. "512特大地震" may come out as "5.12特大地震" after conversion.
        FH_PUNCTUATION = [
            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
            (u'＄', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
            (u'｝', u'}'), (u'｜', u'|')]
FHs = []
if self.change_alpha:
FHs = FH_ALPHA
if self.change_digit:
FHs += FH_NUM
if self.change_punctuation:
FHs += FH_PUNCTUATION
if self.change_space:
FHs += FH_SPACE
self.convert_map = {k: v for k, v in FHs}
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
new_sentence = [None]*len(sentence)
for idx, char in enumerate(sentence):
if char in self.convert_map:
char = self.convert_map[char]
new_sentence[idx] = char
ins[self.field_name].text = ''.join(new_sentence)
return dataset
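# Illustrative check (not part of the original file): the convert_map built above can be
# exercised directly on a raw string, without going through a DataSet:
#     proc = FullSpaceToHalfSpaceProcessor(field_name='raw_sentence')
#     ''.join(proc.convert_map.get(c, c) for c in u"ＦａｓｔＮＬＰ２０１８")   # -> "FastNLP2018"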
class SpeicalSpanProcessor(Processor):
    # This class replaces "special spans" in a sentence with their converted content.
def __init__(self, field_name, new_added_field_name=None):
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)
self.span_converters = []
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
for span_converter in self.span_converters:
sentence = span_converter.find_certain_span_and_replace(sentence)
if self.new_added_field_name!=self.field_name:
new_text_field = TextField(sentence, is_target=False)
ins[self.new_added_field_name] = new_text_field
else:
ins[self.field_name].text = sentence
return dataset
def add_span_converter(self, converter):
assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
.format(type(converter))
self.span_converters.append(converter)
class CWSCharSegProcessor(Processor):
def __init__(self, field_name, new_added_field_name):
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
chars = self._split_sent_into_chars(sentence)
new_token_field = TokenListFiled(chars, is_target=False)
ins[self.new_added_field_name] = new_token_field
return dataset
def _split_sent_into_chars(self, sentence):
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
sp_span_idx = 0
in_span_flag = False
chars = []
num_spans = len(sp_spans)
for idx, char in enumerate(sentence):
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]:
in_span_flag = True
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1:
chars.append(sentence[sp_spans[sp_span_idx]
[0]:sp_spans[sp_span_idx][1]])
in_span_flag = False
sp_span_idx += 1
elif not in_span_flag:
                # TODO: think carefully about how whitespace should be handled here
if char != ' ':
chars.append(char)
else:
pass
return chars
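# Illustrative behaviour of _split_sent_into_chars (not part of the original file): spans matching
# _SPECIAL_TAG_PATTERN are kept as single tokens and spaces are dropped, e.g.
#     "今天 <NUM>度"  ->  ['今', '天', '<NUM>', '度']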
class CWSTagProcessor(Processor):
def __init__(self, field_name, new_added_field_name=None):
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name)
def _generate_tag(self, sentence):
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
sp_span_idx = 0
in_span_flag = False
tag_list = []
word_len = 0
num_spans = len(sp_spans)
for idx, char in enumerate(sentence):
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]:
in_span_flag = True
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1:
word_len += 1
in_span_flag = False
sp_span_idx += 1
elif not in_span_flag:
if char == ' ':
if word_len!=0:
tag_list.extend(self._tags_from_word_len(word_len))
word_len = 0
else:
word_len += 1
else:
pass
if word_len!=0:
tag_list.extend(self._tags_from_word_len(word_len))
return tag_list
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
tag_list = self._generate_tag(sentence)
new_tag_field = SeqLabelField(tag_list)
ins[self.new_added_field_name] = new_tag_field
return dataset
def _tags_from_word_len(self, word_len):
raise NotImplementedError
class CWSSegAppTagProcessor(CWSTagProcessor):
def __init__(self, field_name, new_added_field_name=None):
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)
def _tags_from_word_len(self, word_len):
tag_list = []
for _ in range(word_len-1):
tag_list.append(0)
tag_list.append(1)
return tag_list
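# Worked example (not part of the original file) of the seg-app scheme implemented above: tag 0
# means "the word continues" and tag 1 means "the word ends at this character", so the segmented
# sentence "今天 天气 好" (word lengths 2, 2, 1) yields the tag sequence [0, 1, 0, 1, 1].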
class BigramProcessor(Processor):
    def __init__(self, field_name, new_added_field_name=None):
        super(BigramProcessor, self).__init__(field_name, new_added_field_name)
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
characters = ins[self.field_name].content
bigrams = self._generate_bigram(characters)
new_token_field = TokenListFiled(bigrams)
ins[self.new_added_field_name] = new_token_field
return dataset
    def _generate_bigram(self, characters):
        raise NotImplementedError
class Pre2Post2BigramProcessor(BigramProcessor):
    def __init__(self, field_name, new_added_field_name=None):
        super(Pre2Post2BigramProcessor, self).__init__(field_name, new_added_field_name)
def _generate_bigram(self, characters):
bigrams = []
characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>']
for idx in range(2, len(characters)-2):
cur_char = characters[idx]
pre_pre_char = characters[idx-2]
pre_char = characters[idx-1]
post_char = characters[idx+1]
post_post_char = characters[idx+2]
pre_pre_cur_bigram = pre_pre_char + cur_char
pre_cur_bigram = pre_char + cur_char
cur_post_bigram = cur_char + post_char
cur_post_post_bigram = cur_char + post_post_char
bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char,
pre_pre_cur_bigram, pre_cur_bigram,
cur_post_bigram, cur_post_post_bigram])
return bigrams
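# Worked example (not part of the original file): for characters ['天', '气'] the first character
# '天' contributes the eight features
#     ['<SOS>', '<SOS>', '气', '<EOS>', '<SOS>天', '<SOS>天', '天气', '天<EOS>']
# i.e. four context characters followed by four bigrams, which is why num_bigram_per_char in
# CWSBiLSTMEncoder should be set to 8 when this processor is used.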
# A vocabulary needs to be built at this point, but there is a problem:
# (1) if we follow the Processor pattern, what is returned in that case is not a dataset, so the
#     vocabulary-building work is implemented in another way instead of going through Processor.
class IndexProcessor(Processor):
def __init__(self, vocab, field_name):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
super(IndexProcessor, self).__init__(field_name, None)
self.vocab = vocab
def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
self.vocab = vocab
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
index = [self.vocab.to_index(token) for token in tokens]
ins[self.field_name]._index = index
return dataset
class VocabProcessor(Processor):
def __init__(self, field_name):
super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
self.vocab.update(tokens)
def get_vocab(self):
self.vocab.build_vocab()
return self.vocab
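
A hedged sketch of the vocabulary workflow that VocabProcessor and IndexProcessor implement, shown directly on token lists because the DataSet construction API is not part of this diff:

from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
for tokens in [['今', '天'], ['天', '气']]:
    vocab.update(tokens)        # what VocabProcessor.process does for each instance
vocab.build_vocab()             # what VocabProcessor.get_vocab does before returning
indices = [vocab.to_index(t) for t in ['今', '天', '气']]   # what IndexProcessor.process does per instance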


@@ -0,0 +1,3 @@


@@ -0,0 +1,86 @@
import torch
def seq_lens_to_mask(seq_lens):
    # Build a (batch_size, max_len) mask with ones at positions before each sequence's length.
    batch_size = seq_lens.size(0)
max_len = seq_lens.max()
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
masks = indexes.lt(seq_lens.unsqueeze(1))
return masks
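# Quick illustration (not part of the original file):
#     seq_lens_to_mask(torch.tensor([3, 1]))
# returns a (2, 3) mask with ones at positions before each length:
#     [[1, 1, 1],
#      [1, 0, 0]]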
def cut_long_training_sentences(sentences, max_sample_length=200):
    # Split whitespace-segmented sentences whose character count (spaces excluded) exceeds
    # max_sample_length into shorter lines, cutting only at word boundaries.
    cutted_sentence = []
for sent in sentences:
sent_no_space = sent.replace(' ', '')
if len(sent_no_space) > max_sample_length:
parts = sent.strip().split()
new_line = ''
length = 0
for part in parts:
length += len(part)
new_line += part + ' '
if length > max_sample_length:
new_line = new_line[:-1]
cutted_sentence.append(new_line)
length = 0
new_line = ''
if new_line != '':
cutted_sentence.append(new_line[:-1])
else:
cutted_sentence.append(sent)
return cutted_sentence
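# Worked example (not part of the original file), with max_sample_length=5:
#     cut_long_training_sentences(['aa bb cc dd'], max_sample_length=5)  ->  ['aa bb cc', 'dd']
# Note that a piece is only closed after the limit has been exceeded, so a piece may end up
# slightly longer than max_sample_length.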
from torch import nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
r"""
This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
self.reduce = reduce
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(1)
P = F.softmax(inputs, dim=-1)
        # one-hot mask over the target classes; it is a constant, so it does not need gradients
        class_mask = inputs.data.new(N, C).fill_(0)
        ids = targets.view(-1, 1)
        class_mask = class_mask.scatter(1, ids, 1.)
probs = (P * class_mask).sum(1).view(-1, 1)
log_p = probs.log()
batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p
if self.reduce:
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
return batch_loss
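
A hedged usage sketch of FocalLoss with made-up shapes, assuming torch and the class above are in scope:

loss_fn = FocalLoss(class_num=3, gamma=2, size_average=True, reduce=True)
logits = torch.randn(4, 3, requires_grad=True)   # 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])
loss = loss_fn(logits, targets)                  # scalar: mean focal loss over the batch
loss.backward()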