Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-11-30 03:07:59 +08:00
Add NER data loading and model code; fix a typo in the metrics; change the default LSTM initialization to set the forget-gate bias to 1.
parent 2f5d8967a3
commit 9a8fe42cd4
@@ -428,16 +428,16 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
     prev_bioes_tag = None
     for idx, tag in enumerate(tags):
         tag = tag.lower()
-        bieso_tag, label = tag[:1], tag[2:]
-        if bieso_tag in ('b', 's'):
+        bioes_tag, label = tag[:1], tag[2:]
+        if bioes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
+        elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bieso_tag == 'o':
+        elif bioes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
-        prev_bioes_tag = bieso_tag
+        prev_bioes_tag = bioes_tag
     return [(span[0], (span[1][0], span[1][1] + 1))
             for span in spans
             if span[0] not in ignore_labels
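For reference, a minimal sketch of what the corrected helper produces (tags and expected spans are invented for illustration):

# Illustrative only: expected behaviour of the fixed _bioes_tag_to_spans.
tags = ['B-PER', 'E-PER', 'O', 'S-LOC']
# _bioes_tag_to_spans(tags) should yield [('per', (0, 2)), ('loc', (3, 4))],
# i.e. (label, (start, end)) spans with an exclusive end index.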
@@ -500,8 +500,8 @@ class CNNCharEmbedding(TokenEmbedding):
     """
     Alias: :class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`

-    Generate character embeddings with a CNN. The pipeline is CNN(x) -> activation(x) -> pool -> fc. The outputs of
-    filters with different kernel sizes are concatenated.
+    Generate character embeddings with a CNN. The pipeline is embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool
+    -> fc. The outputs of filters with different kernel sizes are concatenated.

     Example::

@@ -511,13 +511,14 @@ class CNNCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary
     :param embed_size: size of the word embedding. Default: 50.
     :param char_emb_size: size of the character embedding; characters are derived from vocab. Default: 50.
+    :param dropout: dropout probability.
     :param filter_nums: number of filters; must have the same length as kernel_sizes. Default: [40, 30, 20].
     :param kernel_sizes: kernel sizes. Default: [5, 3, 1].
     :param pool_method: pooling method used to merge the character representations into one; supports 'avg' and 'max'.
     :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh', or a custom callable.
     :param min_char_freq: minimum character frequency. Default: 2.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5,
                  filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
                  activation='relu', min_char_freq: int=2):
         super(CNNCharEmbedding, self).__init__(vocab)
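A minimal usage sketch of the new signature, assuming a word Vocabulary has already been built (the values simply mirror the defaults in the diff):

from fastNLP import Vocabulary
from fastNLP.modules.encoder.embedding import CNNCharEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a small demo sentence".split())
# dropout is the newly added argument
char_embed = CNNCharEmbedding(vocab=vocab, embed_size=50, char_emb_size=50, dropout=0.5,
                              filter_nums=[40, 30, 20], kernel_sizes=[5, 3, 1], pool_method='max')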
@@ -526,6 +527,7 @@ class CNNCharEmbedding(TokenEmbedding):
             assert kernel % 2 == 1, "Only odd kernel is allowed."

         assert pool_method in ('max', 'avg')
+        self.dropout = nn.Dropout(dropout, inplace=True)
         self.pool_method = pool_method
         # activation function
         if isinstance(activation, str):
@@ -583,7 +585,7 @@ class CNNCharEmbedding(TokenEmbedding):
         # positions equal to 1 in the mask are padding
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len; positions equal to char_pad_index are padding
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
-
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
         reshaped_chars = reshaped_chars.transpose(1, 2)  # B' x E x M
         conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
@@ -635,7 +637,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     """
     Alias: :class:`fastNLP.modules.LSTMCharEmbedding` :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding`

-    Encode characters with an LSTM.
+    Encode characters with an LSTM. The pipeline is embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool

     Example::

@@ -644,13 +646,14 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param vocab: the vocabulary
     :param embed_size: size of the embedding. Default: 50.
     :param char_emb_size: size of the character embedding. Default: 50.
+    :param dropout: dropout probability.
     :param hidden_size: hidden size of the LSTM; halved per direction if bidirectional. Default: 50.
     :param pool_method: supports 'max' and 'avg'.
     :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom callable.
     :param min_char_freq: minimum character frequency. Default: 2.
     :param bidirectional: whether to use a bidirectional LSTM. Default: True.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50,
                  pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
         super(LSTMCharEmbedding, self).__init__(vocab)

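Similarly, a sketch of the LSTM variant with the new dropout argument (same assumed vocab as in the CNN sketch above):

from fastNLP.modules.encoder.embedding import LSTMCharEmbedding

lstm_char_embed = LSTMCharEmbedding(vocab=vocab, embed_size=50, char_emb_size=50, dropout=0.5,
                                    hidden_size=50, pool_method='max', bidirectional=True)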
@@ -658,7 +661,7 @@ class LSTMCharEmbedding(TokenEmbedding):

         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
-
+        self.dropout = nn.Dropout(dropout, inplace=True)
         # activation function
         if isinstance(activation, str):
             if activation.lower() == 'relu':
@@ -715,7 +718,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         # masked positions are 1
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len; positions equal to char_pad_index are padding
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
-
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
         char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
         lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
@@ -40,12 +40,14 @@ class LSTM(nn.Module):

     def init_param(self):
         for name, param in self.named_parameters():
-            if 'bias_i' in name:
-                param.data.fill_(1)
-            elif 'bias_h' in name:
+            if 'bias' in name:
+                # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
                 param.data.fill_(0)
+                n = param.size(0)
+                start, end = n // 4, n // 2
+                param.data[start:end].fill_(1)
             else:
-                nn.init.xavier_normal_(param)
+                nn.init.xavier_uniform_(param)

     def forward(self, x, seq_len=None, h0=None, c0=None):
         """
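The new init relies on PyTorch's LSTM bias layout [b_i, b_f, b_g, b_o], so the slice n // 4 : n // 2 is the forget gate. A standalone sketch of the same idea on a plain nn.LSTM:

import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20)
for name, param in lstm.named_parameters():
    if 'bias' in name:
        param.data.fill_(0)
        n = param.size(0)                     # 4 * hidden_size
        param.data[n // 4:n // 2].fill_(1)    # forget-gate bias set to 1
    else:
        nn.init.xavier_uniform_(param)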
reproduction/seqence_labelling/ner/__init__.py (new file, 0 lines)

reproduction/seqence_labelling/ner/data/Conll2003Loader.py (new file, 92 lines)
@@ -0,0 +1,92 @@

from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class Conll2003DataLoader(DataSetLoader):
    def __init__(self, task:str='ner', encoding_type:str='bioes'):
        """
        Load English data in the Conll2003 format; information about the dataset can be found at
        https://www.clips.uantwerpen.be/conll2003/ner/. When task is pos, the target of the returned DataSet is taken
        from the 2nd column; when task is chunk, from the 3rd column; when task is ner, from the 4th column. All
        "-DOCSTART- -X- O O" lines are ignored, so the number of samples will be smaller than reported in many papers,
        but since "-DOCSTART- -X- O O" is only a document-separator symbol and should not be predicted, we drop it.
        For the ner and chunk tasks the loaded target follows encoding_type; for the pos task the target is the pos
        column as-is.

        :param task: the labelling task; one of ner, pos, chunk
        """
        assert task in ('ner', 'pos', 'chunk')
        index = {'ner':3, 'pos':1, 'chunk':2}[task]
        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
        self._tag_converters = None
        if task in ('ner', 'chunk'):
            self._tag_converters = [iob2]
            if encoding_type == 'bioes':
                self._tag_converters.append(iob2bioes)

    def load(self, path: str):
        dataset = self._loader.load(path)
        def convert_tag_schema(tags):
            for converter in self._tag_converters:
                tags = converter(tags)
            return tags
        if self._tag_converters:
            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True):
        """
        Read and process the data. Lines starting with '-DOCSTART-' are ignored.

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase all letters
        :return:
        """
        # read the data
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT,
                                    new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=3) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data

if __name__ == '__main__':
    pass
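A hypothetical call to this loader (the paths below are placeholders, not from the commit):

loader = Conll2003DataLoader(task='ner', encoding_type='bioes')
data_info = loader.process({'train': 'conll2003/train.txt', 'dev': 'conll2003/dev.txt'}, lower=True)
print(data_info.vocabs[Const.TARGET].idx2word)
print(data_info.datasets['train'][:2])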
reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2

class OntoNoteNERDataLoader(DataSetLoader):
    """
    Read OntoNotes data that has been converted to the Conll format. See
    https://github.com/yhcc/OntoNotes-5.0-NER for how to convert OntoNotes to the conll format.

    """
    def __init__(self, encoding_type:str='bioes'):
        assert encoding_type in ('bioes', 'bio')
        self.encoding_type = encoding_type
        if encoding_type=='bioes':
            self.encoding_method = iob2bioes
        else:
            self.encoding_method = iob2

    def load(self, path:str)->DataSet:
        """
        Read data from the given file path. The returned DataSet contains the following fields:
            raw_words: List[str]
            target: List[str]

        :param path:
        :return:
        """
        dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
        def convert_to_bio(tags):
            bio_tags = []
            flag = None
            for tag in tags:
                label = tag.strip("()*")
                if '(' in tag:
                    bio_label = 'B-' + label
                    flag = label
                elif flag:
                    bio_label = 'I-' + flag
                else:
                    bio_label = 'O'
                if ')' in tag:
                    flag = None
                bio_tags.append(bio_label)
            return self.encoding_method(bio_tags)

        dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')

        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
                lower:bool=True)->DataInfo:
        """
        Read and process the data. The returned DataInfo contains the following:
            vocabs:
                word: Vocabulary
                target: Vocabulary
            datasets:
                train: DataSet
                    words: List[int], set as input
                    target: int. label, set as both input and target
                    seq_len: int. sentence length, set as both input and target
                    raw_words: List[str]
                xxx (may vary depending on the paths passed in)

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase
        :return:
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT,
                                    new_field_name=Const.INPUT)
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
        word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data


if __name__ == '__main__':
    loader = OntoNoteNERDataLoader()
    dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
    print(dataset.target.value_count())
    print(dataset[:4])


"""
train 115812 2200752
development 15680 304684
test 12217 230111

train 92403 1901772
valid 13606 279180
test 10258 204135
"""
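To make the parenthesis-encoded NER column concrete, a worked example of convert_to_bio followed by the BIOES conversion (the raw tags are invented for illustration):

# column 10 of conll-formatted OntoNotes, e.g. for 'Barack Obama visited Paris'
raw = ['(PERSON*', '*)', '*', '(GPE)']
# convert_to_bio(raw) first builds    ['B-PERSON', 'I-PERSON', 'O', 'B-GPE']
# and with encoding_type='bioes' returns ['B-PERSON', 'E-PERSON', 'O', 'S-GPE']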
reproduction/seqence_labelling/ner/data/utils.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from typing import List

def iob2(tags:List[str])->List[str]:
    """
    Check whether the tags are valid IOB; IOB1 is automatically converted to IOB2.

    :param tags: the tags to convert
    """
    for i, tag in enumerate(tags):
        if tag == "O":
            continue
        split = tag.split("-")
        if len(split) != 2 or split[0] not in ["I", "B"]:
            raise TypeError("The encoding schema is not a valid IOB type.")
        if split[0] == "B":
            continue
        elif i == 0 or tags[i - 1] == "O":  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
        elif tags[i - 1][1:] == tag[1:]:
            continue
        else:  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
    return tags

def iob2bioes(tags:List[str])->List[str]:
    """
    Convert IOB tags into the BIOES encoding.
    :param tags:
    :return:
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
        else:
            split = tag.split('-')[0]
            if split == 'B':
                if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('B-', 'S-'))
            elif split == 'I':
                if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('I-', 'E-'))
            else:
                raise TypeError("Invalid IOB format.")
    return new_tags
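A quick worked example of the two helpers (the expected outputs follow directly from the code above):

tags = ['I-ORG', 'I-ORG', 'O', 'I-PER']
fixed = iob2(tags)        # ['B-ORG', 'I-ORG', 'O', 'B-PER']  (IOB1 repaired to IOB2, in place)
print(iob2bioes(fixed))   # ['B-ORG', 'E-ORG', 'O', 'S-PER']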
reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py (new file, 62 lines)
@@ -0,0 +1,62 @@

import torch
from torch import nn
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding
from fastNLP.modules import LSTM
from fastNLP.modules import ConditionalRandomField, allowed_transitions, TimestepDropout
import torch.nn.functional as F
from fastNLP import Const

class CNNBiLSTMCRF(nn.Module):
    def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'):
        super().__init__()

        self.embedding = Embedding(embed, dropout=0.5)
        self.char_embedding = Embedding(char_embed, dropout=0.5)
        self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim,
                         hidden_size=hidden_size//2, num_layers=num_layers,
                         bidirectional=True, batch_first=True, dropout=dropout)
        self.forward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
        self.backward_fc = nn.Linear(hidden_size//2, len(tag_vocab))

        transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=False)
        self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=False, allowed_transitions=transitions)

        self.dropout = TimestepDropout(dropout, inplace=True)

        for name, param in self.named_parameters():
            if 'ward_fc' in name:
                if param.data.dim()>1:
                    nn.init.xavier_normal_(param)
                else:
                    nn.init.constant_(param, 0)
            if 'crf' in name:
                nn.init.zeros_(param)

    def _forward(self, words, cap_words, seq_len, target=None):
        words = self.embedding(words)
        chars = self.char_embedding(cap_words)
        words = torch.cat([words, chars], dim=-1)
        outputs, _ = self.lstm(words, seq_len)
        self.dropout(outputs)
        forwards, backwards = outputs.chunk(2, dim=-1)

        # forward_logits = F.log_softmax(self.forward_fc(forwards), dim=-1)
        # backward_logits = F.log_softmax(self.backward_fc(backwards), dim=-1)

        logits = self.forward_fc(forwards) + self.backward_fc(backwards)
        self.dropout(logits)

        if target is not None:
            loss = self.crf(logits, target, seq_len_to_mask(seq_len))
            return {Const.LOSS: loss}
        else:
            pred, _ = self.crf.viterbi_decode(logits, seq_len_to_mask(seq_len))
            return {Const.OUTPUT: pred}

    def forward(self, words, cap_words, seq_len, target):
        return self._forward(words, cap_words, seq_len, target)

    def predict(self, words, cap_words, seq_len):
        return self._forward(words, cap_words, seq_len, None)
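As with other fastNLP models, the returned dict keys drive the Trainer. A rough interaction sketch (the tensors, embeddings, and vocab are assumed to be built as in the training scripts below):

# Hypothetical objects; see the training scripts below for the real construction.
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=tag_vocab)
train_out = model(words, cap_words, seq_len, target)   # dict with Const.LOSS (CRF negative log-likelihood)
pred_out = model.predict(words, cap_words, seq_len)    # dict with Const.OUTPUT (Viterbi-decoded tag indices)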
reproduction/seqence_labelling/ner/test/__init__.py (new file, 0 lines)

reproduction/seqence_labelling/ner/test/test.py (new file, 33 lines)
@@ -0,0 +1,33 @@

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes
import unittest

class TestTagSchemaConverter(unittest.TestCase):
    def test_iob2(self):
        tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-ORG', 'O']
        golden = ['B-ORG', 'O']
        self.assertListEqual(golden, iob2(tags))

        tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        self.assertListEqual(golden, iob2(tags))

    def test_iob2bemso(self):
        tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
        golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
        self.assertListEqual(golden, iob2bioes(tags))


def test_conll2003_loader():
    path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt'
    loader = Conll2003DataLoader().load(path)
    print(loader[:3])


if __name__ == '__main__':
    test_conll2003_loader()
@@ -0,0 +1,42 @@


from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding, BertEmbedding
from fastNLP.core.vocabulary import VocabularyOption

from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
from fastNLP import Trainer
from fastNLP import SpanFPreRecMetric
from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD, Adam
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback
import fitlog
fitlog.debug()

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader

encoding_type = 'bioes'

data = Conll2003DataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/conll2003',
                                                                 word_vocab_opt=VocabularyOption(min_freq=3))
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
                              kernel_sizes=[3])
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
                             requires_grad=True)
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std()

model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
                     encoding_type=encoding_type)

optimizer = Adam(model.parameters(), lr=0.001)

callbacks = [GradientClipCallback(clip_type='value'), FitlogCallback({'test':data.datasets['test']}, verbose=1)]

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=0, dev_data=data.datasets['dev'], batch_size=32,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  callbacks=callbacks, num_workers=1, n_epochs=100)
trainer.train()
reproduction/seqence_labelling/ner/train_ontonote.py (new file, 39 lines)
@@ -0,0 +1,39 @@


from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding

from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
from fastNLP import Trainer
from fastNLP import SpanFPreRecMetric
from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD, Adam
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback
import fitlog
fitlog.debug()

from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader

encoding_type = 'bioes'

data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/v4/english')
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
                              kernel_sizes=[3])
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
                             requires_grad=True)

model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
                     encoding_type=encoding_type)

optimizer = Adam(model.parameters(), lr=0.001)

callbacks = [GradientClipCallback(), FitlogCallback(data.datasets['test'], verbose=1)]

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=1, dev_data=data.datasets['dev'], batch_size=32,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  callbacks=callbacks, num_workers=1, n_epochs=100)
trainer.train()