sequence labeling更新

This commit is contained in:
yh_cc 2019-06-19 11:14:41 +08:00
parent 4d138ed7f8
commit 4533427ea3
5 changed files with 24 additions and 23 deletions

View File

@ -63,8 +63,10 @@ class Conll2003DataLoader(DataSetLoader):
data.datasets[name] = dataset
# 对construct vocab
word_vocab = Vocabulary(min_freq=3) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
# word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
# TODO 这样感觉不规范呐
word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT)
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
data.vocabs[Const.INPUT] = word_vocab

View File

@ -87,7 +87,8 @@ class OntoNoteNERDataLoader(DataSetLoader):
# 对construct vocab
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
# word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT)
word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT)
data.vocabs[Const.INPUT] = word_vocab

View File

@ -4,7 +4,7 @@ from torch import nn
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding
from fastNLP.modules import LSTM
from fastNLP.modules import ConditionalRandomField, allowed_transitions, TimestepDropout
from fastNLP.modules import ConditionalRandomField, allowed_transitions
import torch.nn.functional as F
from fastNLP import Const
@ -17,13 +17,12 @@ class CNNBiLSTMCRF(nn.Module):
self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim,
hidden_size=hidden_size//2, num_layers=num_layers,
bidirectional=True, batch_first=True, dropout=dropout)
self.forward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
self.backward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
self.fc = nn.Linear(hidden_size, len(tag_vocab))
transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=False)
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=False, allowed_transitions=transitions)
transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=True)
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=transitions)
self.dropout = TimestepDropout(dropout, inplace=True)
self.dropout = nn.Dropout(dropout, inplace=True)
for name, param in self.named_parameters():
if 'ward_fc' in name:
@ -40,13 +39,8 @@ class CNNBiLSTMCRF(nn.Module):
words = torch.cat([words, chars], dim=-1)
outputs, _ = self.lstm(words, seq_len)
self.dropout(outputs)
forwards, backwards = outputs.chunk(2, dim=-1)
# forward_logits = F.log_softmax(self.forward_fc(forwards), dim=-1)
# backward_logits = F.log_softmax(self.backward_fc(backwards), dim=-1)
logits = self.forward_fc(forwards) + self.backward_fc(backwards)
self.dropout(logits)
logits = F.log_softmax(self.fc(outputs), dim=-1)
if target is not None:
loss = self.crf(logits, target, seq_len_to_mask(seq_len))

View File

@ -10,7 +10,8 @@ from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD, Adam
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback
from fastNLP.core.callback import FitlogCallback, LRScheduler
from torch.optim.lr_scheduler import LambdaLR
import fitlog
fitlog.debug()
@ -19,7 +20,7 @@ from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003Dat
encoding_type = 'bioes'
data = Conll2003DataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/conll2003',
word_vocab_opt=VocabularyOption(min_freq=3))
word_vocab_opt=VocabularyOption(min_freq=2))
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
kernel_sizes=[3])
@ -28,15 +29,18 @@ word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
requires_grad=True)
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std()
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
encoding_type=encoding_type)
optimizer = Adam(model.parameters(), lr=0.001)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
callbacks = [GradientClipCallback(clip_type='value'), FitlogCallback({'test':data.datasets['test']}, verbose=1)]
callbacks = [GradientClipCallback(clip_type='value', clip_value=5), FitlogCallback({'test':data.datasets['test'],
'train':data.datasets['train']}, verbose=1),
scheduler]
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
device=0, dev_data=data.datasets['dev'], batch_size=32,
device=0, dev_data=data.datasets['dev'], batch_size=10,
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
callbacks=callbacks, num_workers=1, n_epochs=100)
trainer.train()

View File

@ -25,10 +25,10 @@ word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
requires_grad=True)
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=2, tag_vocab=data.vocabs[Const.TARGET],
encoding_type=encoding_type)
optimizer = Adam(model.parameters(), lr=0.001)
optimizer = SGD(model.parameters(), lr=0.015, momentum=0.9)
callbacks = [GradientClipCallback(), FitlogCallback(data.datasets['test'], verbose=1)]