Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-11 10:05:30 +08:00)
Merge remote-tracking branch 'origin/dataset' into dataset
This commit is contained in: commit 12e9a93b52
11
fastNLP/api/api.py
Normal file
@@ -0,0 +1,11 @@


class API:
    def __init__(self):
        pass

    def predict(self):
        pass

    def load(self):
        pass
@@ -8,7 +8,6 @@ class Pipeline:

    def add_processor(self, processor):
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
        processor_name = type(processor)
        self.pipeline.append(processor)

    def process(self, dataset):
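For orientation, a hedged sketch of how this registration API might be used with the processors added later in this commit; the Pipeline constructor and the body of process() are not shown in this diff, so the lines below are assumptions kept in comment form.

# Assumed usage of Pipeline, based only on the hunk above:
#
#   pipeline = Pipeline()                                    # assumed no-arg constructor
#   pipeline.add_processor(FullSpaceToHalfSpaceProcessor('raw_sentence'))
#   pipeline.add_processor(CWSCharSegProcessor('raw_sentence', 'chars_list'))
#   dataset = pipeline.process(dataset)                      # assumed to apply processors in order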
0
reproduction/chinese_word_segment/model/__init__.py
Normal file
135
reproduction/chinese_word_segment/model/cws_model.py
Normal file
@@ -0,0 +1,135 @@

from torch import nn
import torch
import torch.nn.functional as F

from fastNLP.modules.decoder.MLP import MLP
from fastNLP.models.base_model import BaseModel
from reproduction.chinese_word_segment.utils import seq_lens_to_mask


class CWSBiLSTMEncoder(BaseModel):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
        super().__init__()

        self.input_size = 0
        self.num_bigram_per_char = num_bigram_per_char
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.embed_drop_p = embed_drop_p
        if self.bidirectional:
            self.hidden_size = hidden_size // 2
            self.num_directions = 2
        else:
            self.hidden_size = hidden_size
            self.num_directions = 1

        if bigram_vocab_num is not None:
            assert num_bigram_per_char is not None, "Specify num_bigram_per_char."

        if vocab_num is not None:
            self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim)
            self.input_size += embed_dim

        if bigram_vocab_num is not None:
            self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim)
            self.input_size += self.num_bigram_per_char * bigram_embed_dim

        # NOTE: num_criterion is never assigned anywhere in this class, so these
        # criterion embeddings are effectively skipped; guarded with getattr to
        # avoid an AttributeError at construction time.
        if getattr(self, 'num_criterion', None) is not None:
            if bidirectional:
                self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                                 embedding_dim=self.hidden_size)
            self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                            embedding_dim=self.hidden_size)

        if self.embed_drop_p is not None:
            self.embedding_drop = nn.Dropout(p=self.embed_drop_p)

        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional,
                            batch_first=True, num_layers=self.num_layers)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.named_parameters():
            if 'bias_hh' in name:
                nn.init.constant_(param, 0)
            elif 'bias_ih' in name:
                nn.init.constant_(param, 1)
            else:
                nn.init.xavier_uniform_(param)

    def init_embedding(self, embedding, embed_name):
        if embed_name == 'bigram':
            self.bigram_embedding.weight.data = torch.from_numpy(embedding)
        elif embed_name == 'char':
            self.char_embedding.weight.data = torch.from_numpy(embedding)

    def forward(self, chars, bigrams=None, seq_lens=None):

        batch_size, max_len = chars.size()

        x_tensor = self.char_embedding(chars)

        if bigrams is not None:
            bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
            x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)

        # Sort by length for pack_padded_sequence, then restore the original order.
        sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
        packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)

        outputs, _ = self.lstm(packed_x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        _, desorted_indices = torch.sort(sorted_indices, descending=False)
        outputs = outputs[desorted_indices]

        return outputs


class CWSBiLSTMSegApp(BaseModel):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2):
        super(CWSBiLSTMSegApp, self).__init__()

        self.tag_size = tag_size

        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
                                              hidden_size, bidirectional, embed_drop_p, num_layers)

        size_layer = [hidden_size, 100, tag_size]
        self.decoder_model = MLP(size_layer)

    def forward(self, **kwargs):
        chars = kwargs['chars']
        if 'bigrams' in kwargs:
            bigrams = kwargs['bigrams']
        else:
            bigrams = None
        seq_lens = kwargs['seq_lens']

        feats = self.encoder_model(chars, bigrams, seq_lens)
        probs = self.decoder_model(feats)

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['pred_prob'] = probs

        return pred_dict

    def loss_fn(self, pred_dict, true_dict):
        seq_lens = pred_dict['seq_lens']
        masks = seq_lens_to_mask(seq_lens).float()

        pred_prob = pred_dict['pred_prob']
        true_y = true_dict['tags']

        # TODO: the loss is currently hard-coded here.
        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
                               true_y.view(-1), reduction='none') * masks.view(-1) / torch.sum(masks)

        return loss

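For reference, a minimal smoke-test sketch of the model above (not part of the commit); the vocabulary size, batch shape, and tag values are illustrative assumptions, and the import path follows the file added in this diff.

# Illustrative check of the shapes CWSBiLSTMSegApp expects and produces.
import torch
from reproduction.chinese_word_segment.model.cws_model import CWSBiLSTMSegApp

model = CWSBiLSTMSegApp(vocab_num=500, embed_dim=100, hidden_size=200, tag_size=2)
chars = torch.randint(0, 500, (4, 10))          # batch of 4 sentences, 10 character ids each (toy values)
seq_lens = torch.tensor([10, 8, 6, 3])          # true lengths before padding
pred = model(chars=chars, seq_lens=seq_lens)
print(pred['pred_prob'].shape)                  # expected: torch.Size([4, 10, 2])

tags = torch.zeros(4, 10, dtype=torch.long)     # dummy seg/app tags
loss = model.loss_fn(pred, {'tags': tags})      # per-position losses, masked by length
loss.sum().backward()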
283
reproduction/chinese_word_segment/process/cws_processor.py
Normal file
@@ -0,0 +1,283 @@

import re

from fastNLP.core.field import SeqLabelField
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet

from fastNLP.api.processor import Processor

# NOTE: TextField, TokenListFiled and SpanConverterBase are referenced below but
# are not imported in this hunk.

_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'


class FullSpaceToHalfSpaceProcessor(Processor):
    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)

        self.change_alpha = change_alpha
        self.change_digit = change_digit
        self.change_punctuation = change_punctuation
        self.change_space = change_space

        FH_SPACE = [(u"　", u" ")]
        FH_NUM = [
            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
        FH_ALPHA = [
            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
            (u"ｚ", u"z"),
            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
            (u"Ｚ", u"Z")]
        # Be careful with punctuation conversion: "5．12特大地震" may come out as
        # "5.12特大地震" after conversion.
        FH_PUNCTUATION = [
            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
            (u'｝', u'}'), (u'｜', u'|')]
        FHs = []
        if self.change_alpha:
            FHs = FH_ALPHA
        if self.change_digit:
            FHs += FH_NUM
        if self.change_punctuation:
            FHs += FH_PUNCTUATION
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            new_sentence = [None] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
                new_sentence[idx] = char
            ins[self.field_name].text = ''.join(new_sentence)
        return dataset


class SpeicalSpanProcessor(Processor):
    # This class replaces "special spans" in a sentence with their corresponding content.
    def __init__(self, field_name, new_added_field_name=None):
        super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)

        self.span_converters = []

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            for span_converter in self.span_converters:
                sentence = span_converter.find_certain_span_and_replace(sentence)
            if self.new_added_field_name != self.field_name:
                new_text_field = TextField(sentence, is_target=False)
                ins[self.new_added_field_name] = new_text_field
            else:
                ins[self.field_name].text = sentence

        return dataset

    def add_span_converter(self, converter):
        assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
            .format(type(converter))
        self.span_converters.append(converter)


class CWSCharSegProcessor(Processor):
    def __init__(self, field_name, new_added_field_name):
        super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            chars = self._split_sent_into_chars(sentence)
            new_token_field = TokenListFiled(chars, is_target=False)
            ins[self.new_added_field_name] = new_token_field

        return dataset

    def _split_sent_into_chars(self, sentence):
        sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
        sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
        sp_span_idx = 0
        in_span_flag = False
        chars = []
        num_spans = len(sp_spans)
        for idx, char in enumerate(sentence):
            if sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][0]:
                in_span_flag = True
            elif in_span_flag and sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][1] - 1:
                chars.append(sentence[sp_spans[sp_span_idx][0]:sp_spans[sp_span_idx][1]])
                in_span_flag = False
                sp_span_idx += 1
            elif not in_span_flag:
                # TODO: how to handle spaces here needs more careful thought.
                if char != ' ':
                    chars.append(char)
                else:
                    pass
        return chars


class CWSTagProcessor(Processor):
    def __init__(self, field_name, new_added_field_name=None):
        super(CWSTagProcessor, self).__init__(field_name, new_added_field_name)

    def _generate_tag(self, sentence):
        sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
        sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
        sp_span_idx = 0
        in_span_flag = False
        tag_list = []
        word_len = 0
        num_spans = len(sp_spans)
        for idx, char in enumerate(sentence):
            if sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][0]:
                in_span_flag = True
            elif in_span_flag and sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][1] - 1:
                word_len += 1
                in_span_flag = False
                sp_span_idx += 1
            elif not in_span_flag:
                if char == ' ':
                    if word_len != 0:
                        tag_list.extend(self._tags_from_word_len(word_len))
                    word_len = 0
                else:
                    word_len += 1
            else:
                pass
        if word_len != 0:
            tag_list.extend(self._tags_from_word_len(word_len))

        return tag_list

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            tag_list = self._generate_tag(sentence)
            new_tag_field = SeqLabelField(tag_list)
            ins[self.new_added_field_name] = new_tag_field
        return dataset

    def _tags_from_word_len(self, word_len):
        raise NotImplementedError


class CWSSegAppTagProcessor(CWSTagProcessor):
    def __init__(self, field_name, new_added_field_name=None):
        super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)

    def _tags_from_word_len(self, word_len):
        # Seg/app tagging: 0 marks "the word continues", 1 marks the last character
        # of a word; e.g. a 3-character word yields [0, 0, 1].
        tag_list = []
        for _ in range(word_len - 1):
            tag_list.append(0)
        tag_list.append(1)
        return tag_list


class BigramProcessor(Processor):
    def __init__(self, field_name, new_added_fielf_name=None):

        super(BigramProcessor, self).__init__(field_name, new_added_fielf_name)

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

        for ins in dataset:
            characters = ins[self.field_name].content
            bigrams = self._generate_bigram(characters)
            new_token_field = TokenListFiled(bigrams)
            ins[self.new_added_field_name] = new_token_field

        return dataset

    def _generate_bigram(self, characters):
        raise NotImplementedError


class Pre2Post2BigramProcessor(BigramProcessor):
    def __init__(self, field_name, new_added_fielf_name=None):

        super(Pre2Post2BigramProcessor, self).__init__(field_name, new_added_fielf_name)

    def _generate_bigram(self, characters):
        # For each character, emit 8 features: the two preceding and two following
        # characters plus the four bigrams formed with them.
        bigrams = []
        characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>']
        for idx in range(2, len(characters) - 2):
            cur_char = characters[idx]
            pre_pre_char = characters[idx - 2]
            pre_char = characters[idx - 1]
            post_char = characters[idx + 1]
            post_post_char = characters[idx + 2]
            pre_pre_cur_bigram = pre_pre_char + cur_char
            pre_cur_bigram = pre_char + cur_char
            cur_post_bigram = cur_char + post_char
            cur_post_post_bigram = cur_char + post_post_char
            bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char,
                            pre_pre_cur_bigram, pre_cur_bigram,
                            cur_post_bigram, cur_post_post_bigram])
        return bigrams


# A vocabulary needs to be built at this point, but there is a problem:
# (1) with the Processor approach, what is returned in this case is not a dataset,
# so vocabulary building is implemented separately rather than through a Processor.
class IndexProcessor(Processor):
    def __init__(self, vocab, field_name):

        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        super(IndexProcessor, self).__init__(field_name, None)
        self.vocab = vocab

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        self.vocab = vocab

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name].content
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.field_name]._index = index

        return dataset


class VocabProcessor(Processor):
    def __init__(self, field_name):

        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary()

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name].content
            self.vocab.update(tokens)

    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab

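For reference, a rough sketch (not part of the commit) of how these processors might be chained; the DataSet `ds` and the field names 'raw_sentence', 'chars_list', 'tags' and 'bigrams_list' are illustrative assumptions, and it presumes the TextField/TokenListFiled classes noted above are importable.

# Illustrative processing chain over a DataSet `ds` whose instances carry a
# 'raw_sentence' text field; all field names here are assumptions.
ds = FullSpaceToHalfSpaceProcessor('raw_sentence').process(ds)
ds = CWSCharSegProcessor('raw_sentence', 'chars_list').process(ds)
ds = CWSSegAppTagProcessor('raw_sentence', 'tags').process(ds)
ds = Pre2Post2BigramProcessor('chars_list', 'bigrams_list').process(ds)

char_vocab_proc = VocabProcessor('chars_list')
char_vocab_proc.process(ds)                      # collects tokens; returns None
ds = IndexProcessor(char_vocab_proc.get_vocab(), 'chars_list').process(ds)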
3
reproduction/chinese_word_segment/train_context.py
Normal file
@@ -0,0 +1,3 @@
86
reproduction/chinese_word_segment/utils.py
Normal file
@@ -0,0 +1,86 @@

import torch


def seq_lens_to_mask(seq_lens):
    batch_size = seq_lens.size(0)
    max_len = seq_lens.max()

    indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
    masks = indexes.lt(seq_lens.unsqueeze(1))

    return masks


def cut_long_training_sentences(sentences, max_sample_length=200):
    cutted_sentence = []
    for sent in sentences:
        sent_no_space = sent.replace(' ', '')
        if len(sent_no_space) > max_sample_length:
            parts = sent.strip().split()
            new_line = ''
            length = 0
            for part in parts:
                length += len(part)
                new_line += part + ' '
                if length > max_sample_length:
                    new_line = new_line[:-1]
                    cutted_sentence.append(new_line)
                    length = 0
                    new_line = ''
            if new_line != '':
                cutted_sentence.append(new_line[:-1])
        else:
            cutted_sentence.append(sent)
    return cutted_sentence


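For reference, a small worked example of seq_lens_to_mask (not part of the commit):

# seq_lens_to_mask(torch.tensor([3, 1])) gives a 2x3 mask
#   [[1, 1, 1],
#    [1, 0, 0]]
# (a bool or uint8 tensor, depending on the PyTorch version).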
from torch import nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection":

        Loss(x, class) = - \alpha (1 - softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha(1D Tensor, Variable): the scalar factor for this criterion
        gamma(float, double): gamma > 0; reduces the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples
        size_average(bool): by default, the losses are averaged over observations for each minibatch.
            However, if size_average is set to False, the losses are instead summed for each minibatch.
    """

    def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=-1)

        # One-hot mask selecting the probability of the target class.
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask.requires_grad = True
        ids = targets.view(-1, 1)
        class_mask = class_mask.scatter(1, ids.data, 1.)

        probs = (P * class_mask).sum(1).view(-1, 1)

        log_p = probs.log()

        batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p
        if self.reduce:
            if self.size_average:
                loss = batch_loss.mean()
            else:
                loss = batch_loss.sum()
            return loss
        return batch_loss
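For reference, a brief usage sketch of FocalLoss (not part of the commit); shapes and values are illustrative.

# Illustrative use on flattened (N, C) logits and N integer targets.
import torch

criterion = FocalLoss(class_num=2, gamma=2, reduce=True)
logits = torch.randn(8, 2, requires_grad=True)   # e.g. 8 character positions, 2 tags
targets = torch.randint(0, 2, (8,))
loss = criterion(logits, targets)                # scalar, since reduce=True and size_average=True
loss.backward()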