Mirror of https://gitee.com/fastnlp/fastNLP.git
1. Modify ELMo to load allennlp weights;
This commit is contained in:
parent fb0ce6cac4
commit 2c9a6e0ba4
@@ -117,6 +117,8 @@ class Vocabulary(object):

        :param str word: the new word
        """
        if word in self._no_create_word:
            self._no_create_word.pop(word)
        self.add(word)

    @_check_build_status
@@ -126,6 +128,9 @@ class Vocabulary(object):

        :param list[str] word_lst: a sequence of words
        """
        for word in word_lst:
            if word in self._no_create_word:
                self._no_create_word.pop(word)
        self.update(word_lst)

    def build_vocab(self):
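A minimal sketch of the behaviour patched above, assuming train_ds, dev_ds and test_ds are existing DataSet objects (names are illustrative): a word registered only through no_create_entry_dataset is tracked in vocab._no_create_word, and explicitly adding it afterwards removes it from that set so it receives its own entry.

from fastNLP import Vocabulary

vocab = Vocabulary()
# words seen only in dev/test are recorded as "no create entry" words
vocab.from_dataset(train_ds, field_name='raw_words',
                   no_create_entry_dataset=[dev_ds, test_ds])
# with the change above, adding such a word explicitly drops it from
# vocab._no_create_word, so it is treated as a regular vocabulary entry
vocab.add_word('dev_only_word')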
@@ -179,16 +179,16 @@ class StaticEmbedding(TokenEmbedding):

    :param model_dir_or_name: a pretrained static embedding can be specified in two ways: either pass the filename of
        an embedding, or pass the name of an embedding. Currently supported names include {`en` or
        `en-glove-840b-300`: glove.840B.300d, `en-glove-6b-50`: glove.6B.50d,
        `en-word2vec-300`: GoogleNews-vectors-negative300}. In the second case the cache is checked for the model
        first, and it is downloaded automatically if it is not present.
    :param requires_grad: whether a gradient is required. Defaults to True.
    :param init_method: how to initialize vectors for words that are not found. Any method from torch.nn.init.* can be
        used; it is called with a tensor argument.
    :param normalize: whether to normalize the vectors so that each vector has norm 1.
    :param bool requires_grad: whether a gradient is required. Defaults to True.
    :param callable init_method: how to initialize vectors for words that are not found. Any method from
        torch.nn.init.* can be used; it is called with a tensor argument.
    :param bool normalize: whether to normalize the vectors so that each vector has norm 1.
    :param bool lower: whether to lowercase the words in vocab before matching them against the pretrained vocabulary.
        Set lower to False if your vocabulary contains uppercase words, or if uppercase words should get their own
        vector representation.
    """
    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
                 normalize=False):
                 normalize=False, lower=False):
        super(StaticEmbedding, self).__init__(vocab)

        # Decide first which static embeddings can be downloaded. We will probably need to set up our own server for this.

        # resolve the cache_path
        if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            PRETRAIN_URL = _get_base_url('static')
@@ -202,8 +202,40 @@ class StaticEmbedding(TokenEmbedding):
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        # load the embedding
        embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
                                          normalize=normalize)
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
            for word, index in vocab:
                if not vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower())  # first add the words that need their own entry
            for word in vocab._no_create_word.keys():  # words that do not need their own entry
                if word in vocab:
                    lowered_word = word.lower()
                    if lowered_word not in lowered_vocab.word_count:
                        lowered_vocab.add_word(lowered_word)
                        lowered_vocab._no_create_word[lowered_word] += 1
            print(f"All words in vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered "
                  f"words.")
            embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method,
                                              normalize=normalize)
            # adapt the mapping from the original vocab to the lowered vocab
            if not hasattr(self, 'words_to_words'):
                self.words_to_words = torch.arange(len(lowered_vocab)).long()
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:
                unknown_idx = embedding.size(0) - 1  # otherwise the last row is the unknown vector
            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                          requires_grad=False)
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if lowered_vocab._is_word_no_create_entry(word):  # no entry needed, already mapped to unknown
                        continue
                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
            self.words_to_words = words_to_words
        else:
            embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
                                              normalize=normalize)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
@@ -301,7 +333,7 @@ class StaticEmbedding(TokenEmbedding):
        if vocab._no_create_word_length>0:
            if vocab.unknown is None:  # create a dedicated unknown entry
                unknown_idx = len(matrix)
                vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous()
                vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
            else:
                unknown_idx = vocab.unknown_idx
            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
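A minimal usage sketch of the new lower option added in the StaticEmbedding hunks above, assuming the import path and a small vocabulary (both illustrative; the pretrained name `en-glove-6b-50` is taken from the docstring):

import torch
from fastNLP import Vocabulary
from fastNLP.modules.encoder.embedding import StaticEmbedding  # import path assumed

vocab = Vocabulary()
vocab.add_word_lst(["The", "the", "apple"])
# with lower=True, "The" is matched against the lowercased pretrained entry "the",
# so it reuses that vector instead of receiving a randomly initialized one
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', lower=True)
words = torch.LongTensor([[vocab.to_index("The"), vocab.to_index("apple")]])
print(embed(words).shape)  # (1, 2, 50)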
@@ -438,19 +470,15 @@ class ElmoEmbedding(ContextualEmbedding):

    :param model_dir_or_name: a pretrained ELMo embedding can be specified in two ways: either pass the filename of the
        ELMo weights, or pass the name of an ELMo version. Currently supported versions include {`en`: English ELMo,
        `cn`: Chinese ELMo}. In the second case the cache is checked for the model first, and it is downloaded
        automatically if it is not present.
    :param layers: str, which layers to return, separated by commas. '2' returns the second layer, '1,2' returns the
        last two layers. The results of the different layers are
        concatenated in this order. Defaults to '2'.
    :param requires_grad: bool, whether this layer requires a gradient. Defaults to False.
        concatenated in this order. Defaults to '2'. 'mix' combines the representations of the different layers with
        learnable weights (whether the weights are trainable follows requires_grad; the initial weights mean-pool the
        three layers; ElmoEmbedding.set_mix_weights_requires_grad() can be used to make only the mix weights trainable).
    :param requires_grad: bool, whether this layer requires a gradient. Defaults to False.
    :param cache_word_reprs: whether to cache the word representations. If set to True, an embedding is generated for
        every word at initialization time and the character encoder is deleted; the cached embeddings are used directly
        afterwards. Defaults to False.
    """
    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en',
                 layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False):
        super(ElmoEmbedding, self).__init__(vocab)
        layers = list(map(int, layers.split(',')))
        assert len(layers) > 0, "Must choose one output"
        for layer in layers:
            assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
        self.layers = layers

        # check whether model_dir_or_name exists and download it if needed
        if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
@@ -464,8 +492,49 @@ class ElmoEmbedding(ContextualEmbedding):
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")
        self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)

        if layers=='mix':
            self.layer_weights = nn.Parameter(torch.zeros(self.model.config['encoder']['n_layers']+1),
                                              requires_grad=requires_grad)
            self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad)
            self._get_outputs = self._get_mixed_outputs
            self._embed_size = self.model.config['encoder']['projection_dim'] * 2
        else:
            layers = list(map(int, layers.split(',')))
            assert len(layers) > 0, "Must choose one output"
            for layer in layers:
                assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
            self.layers = layers
            self._get_outputs = self._get_layer_outputs
            self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2

        self.requires_grad = requires_grad
        self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2

    def _get_mixed_outputs(self, outputs):
        # outputs: num_layers x batch_size x max_len x hidden_size
        # return: batch_size x max_len x hidden_size
        weights = F.softmax(self.layer_weights+1/len(outputs), dim=0).to(outputs)
        outputs = torch.einsum('l,lbij->bij', weights, outputs)
        return self.gamma.to(outputs)*outputs

    def set_mix_weights_requires_grad(self, flag=True):
        """
        When layers is set to 'mix' at ElmoEmbedding initialization, this method controls whether the mix weights are
        trainable. It has no effect if layers is not 'mix'.

        :param bool flag: whether the weights that mix the representations of the different layers are trainable.
        :return:
        """
        if hasattr(self, 'layer_weights'):
            self.layer_weights.requires_grad = flag
            self.gamma.requires_grad = flag

    def _get_layer_outputs(self, outputs):
        if len(self.layers) == 1:
            outputs = outputs[self.layers[0]]
        else:
            outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1)

        return outputs

    def forward(self, words: torch.LongTensor):
        """
@@ -480,15 +549,12 @@ class ElmoEmbedding(ContextualEmbedding):
        if outputs is not None:
            return outputs
        outputs = self.model(words)
        if len(self.layers) == 1:
            outputs = outputs[self.layers[0]]
        else:
            outputs = torch.cat([*outputs[self.layers]], dim=-1)

        return outputs
        return self._get_outputs(outputs)

    def _delete_model_weights(self):
        del self.layers, self.model
        for name in ['layers', 'model', 'layer_weights', 'gamma']:
            if hasattr(self, name):
                delattr(self, name)

    @property
    def requires_grad(self):
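A minimal usage sketch of the new 'mix' mode and set_mix_weights_requires_grad from the ElmoEmbedding hunks above, assuming the import path and that vocab is an existing Vocabulary (both illustrative):

import torch
from fastNLP.modules.encoder.embedding import ElmoEmbedding  # import path assumed

# 'mix' learns a softmax-weighted combination of all ELMo layers plus a gamma scale
embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='mix', requires_grad=False)
# keep the ELMo weights frozen but let only the mixing weights train
embed.set_mix_weights_requires_grad(True)
words = torch.LongTensor([[vocab.to_index(w) for w in ["I", "like", "NLP"]]])
reprs = embed(words)  # (1, 3, 2 * projection_dim)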
@@ -892,10 +958,11 @@ class StackEmbedding(TokenEmbedding):
    def __init__(self, embeds: List[TokenEmbedding]):
        vocabs = []
        for embed in embeds:
            vocabs.append(embed.get_word_vocab())
            if hasattr(embed, 'get_word_vocab'):
                vocabs.append(embed.get_word_vocab())
        _vocab = vocabs[0]
        for vocab in vocabs[1:]:
            assert vocab == _vocab, "All embeddings should use the same word vocabulary."
            assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."

        super(StackEmbedding, self).__init__(_vocab)
        assert isinstance(embeds, list)
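A minimal sketch of stacking embeddings built on the same Vocabulary, reusing the illustrative vocab and embedding objects from the sketches above (import path assumed):

from fastNLP.modules.encoder.embedding import StackEmbedding, StaticEmbedding, ElmoEmbedding  # import path assumed

static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2')
# both embeddings are built on the same vocab, so the assertion above passes;
# the stacked output dimension is the sum of the two embedding sizes
stacked = StackEmbedding([static_embed, elmo_embed])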
@@ -1,93 +0,0 @@

from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class Conll2003DataLoader(DataSetLoader):
    def __init__(self, task:str='ner', encoding_type:str='bioes'):
        """
        Loads English data in the Conll2003 format; information about this dataset can be found at
        https://www.clips.uantwerpen.be/conll2003/ner/. When task is pos, target in the returned DataSet is taken from
        the 2nd column; when task is chunk, target is taken from the 3rd column; when task is ner, target is taken from
        the 4th column. All "-DOCSTART- -X- O O" lines are ignored, which makes the data count smaller than many values
        reported in the literature; since "-DOCSTART- -X- O O" is only a document-separator symbol and should not be
        predicted, lines starting with -DOCSTART- are dropped.
        For the ner and chunk tasks the target of the loaded data follows encoding_type; for the pos task it is simply
        the pos column.

        :param task: the labelling task. One of ner, pos, chunk.
        """
        assert task in ('ner', 'pos', 'chunk')
        index = {'ner':3, 'pos':1, 'chunk':2}[task]
        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
        self._tag_converters = None
        if task in ('ner', 'chunk'):
            self._tag_converters = [iob2]
            if encoding_type == 'bioes':
                self._tag_converters.append(iob2bioes)

    def load(self, path: str):
        dataset = self._loader.load(path)
        def convert_tag_schema(tags):
            for converter in self._tag_converters:
                tags = converter(tags)
            return tags
        if self._tag_converters:
            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True):
        """
        Read and process the data. Lines starting with '-DOCSTART-' are ignored.

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase all letters
        :return:
        """
        # read the data
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.words.lower()
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
                                    no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data

if __name__ == '__main__':
    pass
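A brief, hedged usage sketch of this (now removed) loader; the data directory path is illustrative and Const comes from fastNLP as imported in the file above:

from fastNLP import Const

loader = Conll2003DataLoader(task='ner', encoding_type='bioes')
data_info = loader.process('/path/to/conll2003/', lower=True)  # expects train/dev/test files in the directory
print(data_info.datasets['train'][:2])      # words, target, seq_len (and cap_words) fields
print(len(data_info.vocabs[Const.TARGET]))  # number of BIOES tags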
@@ -1,152 +0,0 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths

from fastNLP.io.dataset_loader import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class OntoNoteNERDataLoader(DataSetLoader):
    """
    Reads OntoNotes data that has been converted to the Conll format. The conversion of OntoNotes to the conll format
    is described at https://github.com/yhcc/OntoNotes-5.0-NER.

    """
    def __init__(self, encoding_type:str='bioes'):
        assert encoding_type in ('bioes', 'bio')
        self.encoding_type = encoding_type
        if encoding_type=='bioes':
            self.encoding_method = iob2bioes
        else:
            self.encoding_method = iob2

    def load(self, path:str)->DataSet:
        """
        Read the data from the given file path. The returned DataSet contains the following fields:
            raw_words: List[str]
            target: List[str]

        :param path:
        :return:
        """
        dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
        def convert_to_bio(tags):
            bio_tags = []
            flag = None
            for tag in tags:
                label = tag.strip("()*")
                if '(' in tag:
                    bio_label = 'B-' + label
                    flag = label
                elif flag:
                    bio_label = 'I-' + flag
                else:
                    bio_label = 'O'
                if ')' in tag:
                    flag = None
                bio_tags.append(bio_label)
            return self.encoding_method(bio_tags)

        def convert_word(words):
            converted_words = []
            for word in words:
                word = word.replace('/.', '.')  # some trailing periods appear as /.
                if not word.startswith('-'):
                    converted_words.append(word)
                    continue
                # the following symbols were escaped, so convert them back
                tfrs = {'-LRB-':'(',
                        '-RRB-': ')',
                        '-LSB-': '[',
                        '-RSB-': ']',
                        '-LCB-': '{',
                        '-RCB-': '}'
                        }
                if word in tfrs:
                    converted_words.append(tfrs[word])
                else:
                    converted_words.append(word)
            return converted_words

        dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
        dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')

        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
                lower:bool=True)->DataInfo:
        """
        Read and process the data. The returned DataInfo contains the following content:
            vocabs:
                word: Vocabulary
                target: Vocabulary
            datasets:
                train: DataSet
                    words: List[int], set as input
                    target: int. label, set as both input and target
                    seq_len: int. sentence length, set as both input and target
                    raw_words: List[str]
                xxx (may vary depending on the given paths)

        :param paths:
        :param word_vocab_opt: initialization options for the vocabulary
        :param lower: whether to lowercase
        :return:
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.words.lower()
            data.datasets[name] = dataset

        # construct the word vocab
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # cap words
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # build the target vocab
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data


if __name__ == '__main__':
    loader = OntoNoteNERDataLoader()
    dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
    print(dataset.target.value_count())
    print(dataset[:4])


"""
train 115812 2200752
development 15680 304684
test 12217 230111

train 92403 1901772
valid 13606 279180
test 10258 204135
"""
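For reference, a hedged worked example of what convert_to_bio in the loader above produces from the OntoNotes parenthesis-style NER column (the tag values are illustrative, not taken from the diff):

# input column:                     ['(PERSON*', '*)', '*', '(GPE*)']
# intermediate BIO tags:            ['B-PERSON', 'I-PERSON', 'O', 'B-GPE']
# returned with encoding_type='bioes': ['B-PERSON', 'E-PERSON', 'O', 'S-GPE']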
@@ -1,49 +0,0 @@
from typing import List


def iob2(tags:List[str])->List[str]:
    """
    Check whether the tags are valid IOB data; IOB1 tags are automatically converted to IOB2.

    :param tags: the tags to convert
    """
    for i, tag in enumerate(tags):
        if tag == "O":
            continue
        split = tag.split("-")
        if len(split) != 2 or split[0] not in ["I", "B"]:
            raise TypeError("The encoding schema is not a valid IOB type.")
        if split[0] == "B":
            continue
        elif i == 0 or tags[i - 1] == "O":  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
        elif tags[i - 1][1:] == tag[1:]:
            continue
        else:  # conversion IOB1 to IOB2
            tags[i] = "B" + tag[1:]
    return tags


def iob2bioes(tags:List[str])->List[str]:
    """
    Convert IOB tags to BIOES encoding.

    :param tags:
    :return:
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
        else:
            split = tag.split('-')[0]
            if split == 'B':
                if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('B-', 'S-'))
            elif split == 'I':
                if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
                    new_tags.append(tag)
                else:
                    new_tags.append(tag.replace('I-', 'E-'))
            else:
                raise TypeError("Invalid IOB format.")
    return new_tags
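A short example of the two converters above applied to an IOB1 sequence (tags are illustrative):

tags = ["I-PER", "I-PER", "O", "I-LOC"]   # IOB1
print(iob2(list(tags)))                   # ['B-PER', 'I-PER', 'O', 'B-LOC']
print(iob2bioes(iob2(list(tags))))        # ['B-PER', 'E-PER', 'O', 'S-LOC']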
@@ -29,13 +29,15 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
                path_pair = ('train', filename)
            if 'dev' in filename:
                if path_pair:
                    raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
                    raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                path_pair = ('dev', filename)
            if 'test' in filename:
                if path_pair:
                    raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
                    raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                path_pair = ('test', filename)
            if path_pair:
                if path_pair[0] in files:
                    raise RuntimeError(f"Multiple file under {paths} have '{path_pair[0]}' in their filename.")
                files[path_pair[0]] = os.path.join(paths, path_pair[1])
        return files
    else:
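A hedged sketch of what check_dataloader_paths returns for a directory input, based on the logic above (directory and file names are illustrative):

paths = check_dataloader_paths('data/conll2003/')
# assuming the directory contains train.txt, dev.txt and test.txt, this yields
# {'train': 'data/conll2003/train.txt',
#  'dev': 'data/conll2003/dev.txt',
#  'test': 'data/conll2003/test.txt'}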