From 9a8fe42cd4a322d0639fdd64d05574e70de55013 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Tue, 18 Jun 2019 10:02:24 +0800
Subject: [PATCH] Add NER data loading and model code; fix a typo in metrics;
 change the default LSTM initialization to set the forget gate bias to 1.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/metrics.py                       |  10 +-
 fastNLP/modules/encoder/embedding.py          |  19 +--
 fastNLP/modules/encoder/lstm.py               |  10 +-
 .../seqence_labelling/ner/__init__.py         |   0
 .../ner/data/Conll2003Loader.py               |  92 +++++++++++++
 .../ner/data/OntoNoteLoader.py                | 130 ++++++++++++++++++
 .../seqence_labelling/ner/data/utils.py       |  49 +++++++
 .../ner/model/lstm_cnn_crf.py                 |  62 +++++++++
 .../seqence_labelling/ner/test/__init__.py    |   0
 .../seqence_labelling/ner/test/test.py        |  33 +++++
 .../ner/train_cnn_lstm_crf_conll2003.py       |  42 ++++++
 .../seqence_labelling/ner/train_ontonote.py   |  39 ++++++
 12 files changed, 469 insertions(+), 17 deletions(-)
 create mode 100644 reproduction/seqence_labelling/ner/__init__.py
 create mode 100644 reproduction/seqence_labelling/ner/data/Conll2003Loader.py
 create mode 100644 reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
 create mode 100644 reproduction/seqence_labelling/ner/data/utils.py
 create mode 100644 reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py
 create mode 100644 reproduction/seqence_labelling/ner/test/__init__.py
 create mode 100644 reproduction/seqence_labelling/ner/test/test.py
 create mode 100644 reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py
 create mode 100644 reproduction/seqence_labelling/ner/train_ontonote.py

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index cfcb9039..d54bf8ec 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -428,16 +428,16 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
     prev_bioes_tag = None
     for idx, tag in enumerate(tags):
         tag = tag.lower()
-        bieso_tag, label = tag[:1], tag[2:]
-        if bieso_tag in ('b', 's'):
+        bioes_tag, label = tag[:1], tag[2:]
+        if bioes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
+        elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bieso_tag == 'o':
+        elif bioes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
-        prev_bioes_tag = bieso_tag
+        prev_bioes_tag = bioes_tag
     return [(span[0], (span[1][0], span[1][1] + 1))
             for span in spans
             if span[0] not in ignore_labels
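The renames above are purely cosmetic (`bieso_tag` becomes `bioes_tag`); the decoding logic is unchanged. For readers who want to see what `_bioes_tag_to_spans` computes, here is an equivalent standalone sketch (not part of the patch, with `ignore_labels` omitted for brevity):

    # Sketch of the BIOES span decoding done by _bioes_tag_to_spans.
    def bioes_tag_to_spans(tags):
        spans = []             # items: (label, [start, end]), ends inclusive
        prev_bioes_tag = None
        for idx, tag in enumerate(tags):
            tag = tag.lower()
            bioes_tag, label = tag[:1], tag[2:]
            if bioes_tag in ('b', 's'):
                spans.append((label, [idx, idx]))    # open a new span (or singleton)
            elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
                spans[-1][1][1] = idx                # extend the running span
            elif bioes_tag == 'o':
                pass
            else:
                spans.append((label, [idx, idx]))    # tolerate ill-formed sequences
            prev_bioes_tag = bioes_tag
        # convert to half-open (start, end) intervals
        return [(label, (start, end + 1)) for label, (start, end) in spans]

    print(bioes_tag_to_spans(['B-PER', 'E-PER', 'O', 'S-LOC']))
    # [('per', (0, 2)), ('loc', (3, 4))]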
diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index e8fe903b..121bc950 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -500,8 +500,8 @@ class CNNCharEmbedding(TokenEmbedding):
     """
     别名::class:`fastNLP.modules.CNNCharEmbedding`   :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`
 
-    使用CNN生成character embedding。CNN的结果为, CNN(x) -> activation(x) -> pool -> fc. 不同的kernel大小的fitler结果是
-    concat起来的。
+    使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool
+    -> fc. 不同的kernel大小的filter结果是concat起来的。
 
     Example::
 
@@ -511,13 +511,14 @@ class CNNCharEmbedding(TokenEmbedding):
     :param vocab: 词表
     :param embed_size: 该word embedding的大小,默认值为50.
     :param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50.
+    :param dropout: 以多大的概率drop
     :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
     :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
     :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
     :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
     :param min_char_freq: character的最少出现次数。默认值为2.
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5,
                  filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max',
                  activation='relu', min_char_freq: int=2):
         super(CNNCharEmbedding, self).__init__(vocab)
@@ -526,6 +527,7 @@ class CNNCharEmbedding(TokenEmbedding):
             assert kernel % 2 == 1, "Only odd kernel is allowed."
 
         assert pool_method in ('max', 'avg')
+        self.dropout = nn.Dropout(dropout, inplace=True)
         self.pool_method = pool_method
         # activation function
         if isinstance(activation, str):
@@ -583,7 +585,7 @@ class CNNCharEmbedding(TokenEmbedding):
         # 为1的地方为mask
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
-
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
         reshaped_chars = reshaped_chars.transpose(1, 2)  # B' x E x M
         conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
@@ -635,7 +637,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     """
     别名::class:`fastNLP.modules.LSTMCharEmbedding`   :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding`
 
-    使用LSTM的方式对character进行encode.
+    使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool
 
     Example::
 
@@ -644,13 +646,14 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param vocab: 词表
     :param embed_size: embedding的大小。默认值为50.
    :param char_emb_size: character的embedding的大小。默认值为50.
+    :param dropout: 以多大概率drop
     :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
     :param pool_method: 支持'max', 'avg'
     :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
     :param min_char_freq: character的最小出现次数。默认值为2.
    :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
     """
-    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, hidden_size=50,
+    def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50,
                  pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True):
         super(LSTMCharEmbedding, self).__init__(vocab)
 
@@ -658,7 +661,7 @@ class LSTMCharEmbedding(TokenEmbedding):
 
         assert pool_method in ('max', 'avg')
         self.pool_method = pool_method
-
+        self.dropout = nn.Dropout(dropout, inplace=True)
         # activation function
         if isinstance(activation, str):
             if activation.lower() == 'relu':
@@ -715,7 +718,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         # 为mask的地方为1
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
-
+        chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
         char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
         lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 3b97f4a7..537a446d 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -40,12 +40,14 @@ class LSTM(nn.Module):
 
     def init_param(self):
         for name, param in self.named_parameters():
-            if 'bias_i' in name:
-                param.data.fill_(1)
-            elif 'bias_h' in name:
+            if 'bias' in name:
+                # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
                 param.data.fill_(0)
+                n = param.size(0)
+                start, end = n // 4, n // 2
+                param.data[start:end].fill_(1)
             else:
-                nn.init.xavier_normal_(param)
+                nn.init.xavier_uniform_(param)
 
     def forward(self, x, seq_len=None, h0=None, c0=None):
         """
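For context on the change above: PyTorch stores each LSTM bias vector as the concatenation [b_ii | b_if | b_ig | b_io], so the slice from n//4 to n//2 is exactly the forget-gate section. A standalone sketch (not part of the patch) applying the same initialization to a plain nn.LSTM:

    import torch
    from torch import nn

    # Mirror of the patched LSTM.init_param: zero all biases, then set the
    # forget-gate slice to 1 so the cell retains state early in training.
    lstm = nn.LSTM(input_size=100, hidden_size=200, batch_first=True)
    for name, param in lstm.named_parameters():
        if 'bias' in name:
            param.data.fill_(0)
            n = param.size(0)                    # n == 4 * hidden_size
            param.data[n // 4:n // 2].fill_(1)   # forget-gate bias
        else:
            nn.init.xavier_uniform_(param)

Since both bias_ih and bias_hh receive this treatment, the effective forget-gate bias at the first step is 2.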
diff --git a/reproduction/seqence_labelling/ner/__init__.py b/reproduction/seqence_labelling/ner/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
new file mode 100644
index 00000000..65ed7ab8
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
@@ -0,0 +1,92 @@
+
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from typing import Union, Dict
+from fastNLP import Vocabulary
+from fastNLP import Const
+from reproduction.utils import check_dataloader_paths
+
+from fastNLP.io.dataset_loader import ConllLoader
+from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
+
+
+class Conll2003DataLoader(DataSetLoader):
+    def __init__(self, task:str='ner', encoding_type:str='bioes'):
+        """
+        加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos
+            时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回
+            的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但
+            鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的该值。
+            ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。
+
+        :param task: 指定需要标注任务。可选ner, pos, chunk
+        """
+        assert task in ('ner', 'pos', 'chunk')
+        index = {'ner':3, 'pos':1, 'chunk':2}[task]
+        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
+        self._tag_converters = None
+        if task in ('ner', 'chunk'):
+            self._tag_converters = [iob2]
+            if encoding_type == 'bioes':
+                self._tag_converters.append(iob2bioes)
+
+    def load(self, path: str):
+        dataset = self._loader.load(path)
+        def convert_tag_schema(tags):
+            for converter in self._tag_converters:
+                tags = converter(tags)
+            return tags
+        if self._tag_converters:
+            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
+        return dataset
+
+    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True):
+        """
+        读取并处理数据。数据中的'-DOCSTART-'开头的行会被忽略
+
+        :param paths:
+        :param word_vocab_opt: vocabulary的初始化值
+        :param lower: 是否将所有字母转为小写
+        :return:
+        """
+        # 读取数据
+        paths = check_dataloader_paths(paths)
+        data = DataInfo()
+        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
+        target_fields = [Const.TARGET, Const.INPUT_LEN]
+        for name, path in paths.items():
+            dataset = self.load(path)
+            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
+            if lower:
+                dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT,
+                                    new_field_name=Const.INPUT)
+            data.datasets[name] = dataset
+
+        # 对construct vocab
+        word_vocab = Vocabulary(min_freq=3) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
+        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
+        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
+        data.vocabs[Const.INPUT] = word_vocab
+
+        # cap words
+        cap_word_vocab = Vocabulary()
+        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
+        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
+        input_fields.append('cap_words')
+        data.vocabs['cap_words'] = cap_word_vocab
+
+        # 对target建vocab
+        target_vocab = Vocabulary(unknown=None, padding=None)
+        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
+        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
+        data.vocabs[Const.TARGET] = target_vocab
+
+        for name, dataset in data.datasets.items():
+            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
+            dataset.set_input(*input_fields)
+            dataset.set_target(*target_fields)
+
+        return data
+
+if __name__ == '__main__':
+    pass
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
new file mode 100644
index 00000000..bf1ab71e
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
@@ -0,0 +1,130 @@
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from typing import Union, Dict
+from fastNLP import DataSet
+from fastNLP import Vocabulary
+from fastNLP import Const
+from reproduction.utils import check_dataloader_paths
+
+from fastNLP.io.dataset_loader import ConllLoader
+from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
+
+class OntoNoteNERDataLoader(DataSetLoader):
+    """
+    用于读取处理为Conll格式后的OntoNote数据。将OntoNote数据处理为conll格式的过程可以参考https://github.com/yhcc/OntoNotes-5.0-NER。
+
+    """
+    def __init__(self, encoding_type:str='bioes'):
+        assert encoding_type in ('bioes', 'bio')
+        self.encoding_type = encoding_type
+        if encoding_type=='bioes':
+            self.encoding_method = iob2bioes
+        else:
+            self.encoding_method = iob2
+
+    def load(self, path:str)->DataSet:
+        """
+        给定一个文件路径,读取数据。返回的DataSet包含以下的field
+            raw_words: List[str]
+            target: List[str]
+ + :param path: + :return: + """ + dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path) + def convert_to_bio(tags): + bio_tags = [] + flag = None + for tag in tags: + label = tag.strip("()*") + if '(' in tag: + bio_label = 'B-' + label + flag = label + elif flag: + bio_label = 'I-' + flag + else: + bio_label = 'O' + if ')' in tag: + flag = None + bio_tags.append(bio_label) + return self.encoding_method(bio_tags) + + dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') + + return dataset + + def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, + lower:bool=True)->DataInfo: + """ + 读取并处理数据。返回的DataInfo包含以下的内容 + vocabs: + word: Vocabulary + target: Vocabulary + datasets: + train: DataSet + words: List[int], 被设置为input + target: int. label,被同时设置为input和target + seq_len: int. 句子的长度,被同时设置为input和target + raw_words: List[str] + xxx(根据传入的paths可能有所变化) + + :param paths: + :param word_vocab_opt: vocabulary的初始化值 + :param lower: 是否使用小写 + :return: + """ + paths = check_dataloader_paths(paths) + data = DataInfo() + input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] + target_fields = [Const.TARGET, Const.INPUT_LEN] + for name, path in paths.items(): + dataset = self.load(path) + dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) + if lower: + dataset.apply_field(lambda words:[word.lower() for word in words], field_name=Const.INPUT, + new_field_name=Const.INPUT) + data.datasets[name] = dataset + + # 对construct vocab + word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) + word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') + word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT) + data.vocabs[Const.INPUT] = word_vocab + + # cap words + cap_word_vocab = Vocabulary() + cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') + cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') + input_fields.append('cap_words') + data.vocabs['cap_words'] = cap_word_vocab + + # 对target建vocab + target_vocab = Vocabulary(unknown=None, padding=None) + target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) + target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) + data.vocabs[Const.TARGET] = target_vocab + + for name, dataset in data.datasets.items(): + dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) + dataset.set_input(*input_fields) + dataset.set_target(*target_fields) + + return data + + +if __name__ == '__main__': + loader = OntoNoteNERDataLoader() + dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt') + print(dataset.target.value_count()) + print(dataset[:4]) + + +""" +train 115812 2200752 +development 15680 304684 +test 12217 230111 + +train 92403 1901772 +valid 13606 279180 +test 10258 204135 +""" \ No newline at end of file diff --git a/reproduction/seqence_labelling/ner/data/utils.py b/reproduction/seqence_labelling/ner/data/utils.py new file mode 100644 index 00000000..8f7af792 --- /dev/null +++ b/reproduction/seqence_labelling/ner/data/utils.py @@ -0,0 +1,49 @@ +from typing import List + +def iob2(tags:List[str])->List[str]: + """ + 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。 + + :param tags: 需要转换的tags + """ + for i, tag in enumerate(tags): + if tag == "O": + continue + split = tag.split("-") + if len(split) != 2 or 
split[0] not in ["I", "B"]:
+            raise TypeError("The encoding schema is not a valid IOB type.")
+        if split[0] == "B":
+            continue
+        elif i == 0 or tags[i - 1] == "O":  # conversion IOB1 to IOB2
+            tags[i] = "B" + tag[1:]
+        elif tags[i - 1][1:] == tag[1:]:
+            continue
+        else:  # conversion IOB1 to IOB2
+            tags[i] = "B" + tag[1:]
+    return tags
+
+def iob2bioes(tags:List[str])->List[str]:
+    """
+    将iob的tag转换为bioes编码
+    :param tags:
+    :return:
+    """
+    new_tags = []
+    for i, tag in enumerate(tags):
+        if tag == 'O':
+            new_tags.append(tag)
+        else:
+            split = tag.split('-')[0]
+            if split == 'B':
+                if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
+                    new_tags.append(tag)
+                else:
+                    new_tags.append(tag.replace('B-', 'S-'))
+            elif split == 'I':
+                if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
+                    new_tags.append(tag)
+                else:
+                    new_tags.append(tag.replace('I-', 'E-'))
+            else:
+                raise TypeError("Invalid IOB format.")
+    return new_tags
diff --git a/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py
new file mode 100644
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py
@@ -0,0 +1,62 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from fastNLP import seq_len_to_mask
+from fastNLP import Const
+from fastNLP.modules import LSTM
+from fastNLP.modules import ConditionalRandomField, allowed_transitions
+
+
+class CNNBiLSTMCRF(nn.Module):
+    def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, encoding_type='bioes'):
+        super().__init__()
+        self.embedding = embed
+        self.char_embedding = char_embed
+        self.lstm = LSTM(input_size=self.embedding.embed_size + self.char_embedding.embed_size,
+                         hidden_size=hidden_size//2, num_layers=num_layers,
+                         bidirectional=True, batch_first=True)
+        self.forward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
+        self.backward_fc = nn.Linear(hidden_size//2, len(tag_vocab))
+        self.dropout = nn.Dropout(0.5, inplace=True)
+
+        trans = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=True)
+        self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans)
+
+        self.init_params()
+
+    def init_params(self):
+        # linear layers get xavier init; vectors start at zero, CRF transitions at zero
+        for name, param in self.named_parameters():
+            if param.data.dim() > 1:
+                nn.init.xavier_normal_(param)
+            else:
+                nn.init.constant_(param, 0)
+            if 'crf' in name:
+                nn.init.zeros_(param)
+
+    def _forward(self, words, cap_words, seq_len, target=None):
+        words = self.embedding(words)
+        chars = self.char_embedding(cap_words)
+        words = torch.cat([words, chars], dim=-1)
+        outputs, _ = self.lstm(words, seq_len)
+        self.dropout(outputs)
+        forwards, backwards = outputs.chunk(2, dim=-1)
+
+        # forward_logits = F.log_softmax(self.forward_fc(forwards), dim=-1)
+        # backward_logits = F.log_softmax(self.backward_fc(backwards), dim=-1)
+
+        logits = self.forward_fc(forwards) + self.backward_fc(backwards)
+        self.dropout(logits)
+
+        if target is not None:
+            loss = self.crf(logits, target, seq_len_to_mask(seq_len))
+            return {Const.LOSS: loss}
+        else:
+            pred, _ = self.crf.viterbi_decode(logits, seq_len_to_mask(seq_len))
+            return {Const.OUTPUT: pred}
+
+    def forward(self, words, cap_words, seq_len, target):
+        return self._forward(words, cap_words, seq_len, target)
+
+    def predict(self, words, cap_words, seq_len):
+        return self._forward(words, cap_words, seq_len, None)
diff --git a/reproduction/seqence_labelling/ner/test/__init__.py b/reproduction/seqence_labelling/ner/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/seqence_labelling/ner/test/test.py b/reproduction/seqence_labelling/ner/test/test.py
new file mode 100644
index 00000000..09d0f468
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/test/test.py
@@ -0,0 +1,33 @@
+
+from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
+from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes
+import unittest
+
+class TestTagSchemaConverter(unittest.TestCase):
+    def test_iob2(self):
+        tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
+        golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
+        self.assertListEqual(golden, iob2(tags))
+
+        tags = ['I-ORG', 'O']
+        golden = ['B-ORG', 'O']
+        self.assertListEqual(golden, iob2(tags))
+
+        tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']
+        golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
+        self.assertListEqual(golden, iob2(tags))
+
+    def test_iob2bioes(self):
+        tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
+        golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']
+        self.assertListEqual(golden, iob2bioes(tags))
+
+
+def test_conll2003_loader():
+    path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt'
+    loader = Conll2003DataLoader().load(path)
+    print(loader[:3])
+
+
+if __name__ == '__main__':
+    test_conll2003_loader()
\ No newline at end of file
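The converters exercised by these tests can also be used directly. The round trip below matches the assertions above (IOB1 input is first normalized to IOB2, then re-encoded as BIOES):

    from reproduction.seqence_labelling.ner.data.utils import iob2, iob2bioes

    tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O']   # IOB1-style input
    tags = iob2(tags)
    print(tags)             # ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
    print(iob2bioes(tags))  # ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O']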
diff --git a/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py
new file mode 100644
index 00000000..278ff42f
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py
@@ -0,0 +1,42 @@
+
+
+from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding, BertEmbedding
+from fastNLP.core.vocabulary import VocabularyOption
+
+from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
+from fastNLP import Trainer
+from fastNLP import SpanFPreRecMetric
+from fastNLP import BucketSampler
+from fastNLP import Const
+from torch.optim import SGD, Adam
+from fastNLP import GradientClipCallback
+from fastNLP.core.callback import FitlogCallback
+import fitlog
+fitlog.debug()
+
+from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
+
+encoding_type = 'bioes'
+
+data = Conll2003DataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/conll2003',
+                                                                word_vocab_opt=VocabularyOption(min_freq=3))
+print(data)
+char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
+                              kernel_sizes=[3])
+word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
+                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
+                             requires_grad=True)
+word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std()
+
+model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
+                     encoding_type=encoding_type)
+
+optimizer = Adam(model.parameters(), lr=0.001)
+
+callbacks = [GradientClipCallback(clip_type='value'), FitlogCallback({'test':data.datasets['test']}, verbose=1)]
+
+trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
+                  device=0, dev_data=data.datasets['dev'], batch_size=32,
+                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
+                  callbacks=callbacks, num_workers=1, n_epochs=100)
+trainer.train()
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/ner/train_ontonote.py b/reproduction/seqence_labelling/ner/train_ontonote.py
new file mode 100644
index 00000000..6f443dfd
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/train_ontonote.py
@@ -0,0 +1,39 @@
+
+
+from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding
+
+from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
+from fastNLP import Trainer
+from fastNLP import SpanFPreRecMetric
+from fastNLP import BucketSampler
+from fastNLP import Const
+from torch.optim import SGD, Adam
+from fastNLP import GradientClipCallback
+from fastNLP.core.callback import FitlogCallback
+import fitlog
+fitlog.debug()
+
+from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader
+
+encoding_type = 'bioes'
+
+data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/v4/english')
+print(data)
+char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
+                              kernel_sizes=[3])
+word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
+                             model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
+                             requires_grad=True)
+
+model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
+                     encoding_type=encoding_type)
+
+optimizer = Adam(model.parameters(), lr=0.001)
+
+callbacks = [GradientClipCallback(), 
FitlogCallback(data.datasets['test'], verbose=1)] + +trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(), + device=1, dev_data=data.datasets['dev'], batch_size=32, + metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), + callbacks=callbacks, num_workers=1, n_epochs=100) +trainer.train() \ No newline at end of file
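For readers unfamiliar with the OntoNotes column format consumed by this script: the NE column marks entities with parenthesis markers such as '(PERSON*' and '*)', which OntoNoteNERDataLoader.load first turns into BIO tags before applying the configured encoding method. A standalone copy of that helper with a worked example:

    # Same logic as the convert_to_bio closure in OntoNoteNERDataLoader.load.
    def convert_to_bio(tags):
        bio_tags = []
        flag = None                      # label of the currently open entity
        for tag in tags:
            label = tag.strip("()*")
            if '(' in tag:               # '(LABEL*' opens an entity
                bio_label = 'B-' + label
                flag = label
            elif flag:                   # still inside an open entity
                bio_label = 'I-' + flag
            else:
                bio_label = 'O'
            if ')' in tag:               # '*)' closes the entity
                flag = None
            bio_tags.append(bio_label)
        return bio_tags

    print(convert_to_bio(['(PERSON*', '*)', '*', '(GPE)']))
    # ['B-PERSON', 'I-PERSON', 'O', 'B-GPE']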