Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-02 12:17:35 +08:00

Add methods to DataBundle; add docstrings for BiLSTMCRF

This commit is contained in:
parent b134c9f7e7
commit 9529f89abd
@@ -575,18 +575,18 @@ class DataSet(object):
         """
         return len(self)
 
-    def rename_field(self, old_name, new_name):
+    def rename_field(self, field_name, new_field_name):
         """
        Rename a field.
 
-        :param str old_name: the field's current name.
-        :param str new_name: the new name to use.
+        :param str field_name: the field's current name.
+        :param str new_field_name: the new name to use.
         """
-        if old_name in self.field_arrays:
-            self.field_arrays[new_name] = self.field_arrays.pop(old_name)
-            self.field_arrays[new_name].name = new_name
+        if field_name in self.field_arrays:
+            self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
+            self.field_arrays[new_field_name].name = new_field_name
         else:
-            raise KeyError("DataSet has no field named {}.".format(old_name))
+            raise KeyError("DataSet has no field named {}.".format(field_name))
         return self
 
     def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
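A minimal usage sketch of the renamed signature (the dataset contents below are invented for illustration):

    from fastNLP import DataSet

    ds = DataSet({"raw_words": ["hello world", "goodbye"], "label": [0, 1]})
    ds.rename_field(field_name="raw_words", new_field_name="words")  # returns self, so calls can chain
    assert ds.has_field("words") and not ds.has_field("raw_words")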
@@ -139,9 +139,44 @@ class DataBundle:
             dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
         return self
 
+    def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
+        """
+        Set the padding value of the Field named field_name to pad_val in every DataSet of this DataBundle.
+
+        :param str field_name:
+        :param int pad_val:
+        :param bool ignore_miss_dataset: when a DataSet has no field with this name: if True, simply skip that
+            DataSet; if False, raise an error
+        :return: self
+        """
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                dataset.set_pad_val(field_name=field_name, pad_val=pad_val)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name}.")
+        return self
+
+    def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True):
+        """
+        Set ignore_type to flag for the Fields named in *field_names in every DataSet of this DataBundle.
+
+        :param str field_names:
+        :param bool flag:
+        :param bool ignore_miss_dataset: when a DataSet has no field with this name: if True, simply skip that
+            DataSet; if False, raise an error
+        :return: self
+        """
+        for name, dataset in self.datasets.items():
+            for field_name in field_names:
+                if dataset.has_field(field_name=field_name):
+                    dataset.set_ignore_type(field_name, flag=flag)
+                elif not ignore_miss_dataset:
+                    raise KeyError(f"{field_name} not found DataSet:{name}.")
+        return self
+
     def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True):
         """
-        Copy every field named field_name in the DataBundle to a new field called new_field_name.
+        In every DataSet of this DataBundle, copy the Field named field_name and name the copy new_field_name.
 
         :param str field_name:
         :param str new_field_name:
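A hedged sketch of the two new setters; `bundle` stands for any DataBundle produced elsewhere (e.g. by a Loader or Pipe), and the field names are assumptions:

    # pad the 'words' field with 0 in every DataSet of the bundle
    bundle.set_pad_val("words", pad_val=0)
    # skip type inference/checking for 'raw_words' everywhere
    bundle.set_ignore_type("raw_words", flag=True)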
@@ -156,9 +191,42 @@ class DataBundle:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
         return self
 
+    def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True):
+        """
+        Rename the field named field_name to new_field_name in every DataSet of this DataBundle.
+
+        :param str field_name:
+        :param str new_field_name:
+        :param bool ignore_miss_dataset: when a DataSet has no field with this name: if True, simply skip that
+            DataSet; if False, raise an error
+        :return: self
+        """
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                dataset.rename_field(field_name=field_name, new_field_name=new_field_name)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name}.")
+        return self
+
+    def delete_field(self, field_name, ignore_miss_dataset=True):
+        """
+        Delete the field named field_name from every DataSet of this DataBundle.
+
+        :param str field_name:
+        :param bool ignore_miss_dataset: when a DataSet has no field with this name: if True, simply skip that
+            DataSet; if False, raise an error
+        :return: self
+        """
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                dataset.delete_field(field_name=field_name)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name}.")
+        return self
+
     def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs):
         """
-        Call the apply method on every dataset in the DataBundle
+        Call the apply_field method on every dataset in the DataBundle
 
         :param callable func: its input is the content of the field named `field_name` in each instance.
         :param str field_name: which field is passed to func.
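Taken together with copy_field, these methods allow bundle-wide field edits. A small sketch with hypothetical field names (`bundle` as above); every method returns self, so calls can chain:

    bundle.rename_field("words", "tokens")
    # delete 'pos' wherever it exists; pass ignore_miss_dataset=False to make a missing field an error
    bundle.delete_field("pos")
    # derive a sequence-length field from the tokens
    bundle.apply_field(lambda tokens: len(tokens), field_name="tokens", new_field_name="seq_len")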
@@ -4,7 +4,7 @@
 __all__ = [
     "SeqLabeling",
     "AdvSeqLabel",
-    # "BiLSTMCRF"
+    "BiLSTMCRF"
 ]
 
 import torch
@@ -14,7 +14,6 @@ import torch.nn.functional as F
 from .base_model import BaseModel
 from ..core.const import Const as C
 from ..core.utils import seq_len_to_mask
-from ..embeddings import embedding
 from ..embeddings import get_embeddings
 from ..modules import ConditionalRandomField
 from ..modules import LSTM
@@ -24,18 +23,15 @@ from ..modules.decoder.crf import allowed_transitions
 
 class BiLSTMCRF(BaseModel):
     """
-    The architecture is BiLSTM + FC + Dropout + CRF.
+    The architecture is embedding + BiLSTM + FC + Dropout + CRF.
 
-    .. todo::
-        documentation to be completed
-
-    :param embed: tuple:
-    :param num_classes:
-    :param num_layers:
-    :param hidden_size:
-    :param dropout:
-    :param target_vocab:
-    :param encoding_type:
+    :param embed: supports (1) any fastNLP Embedding, or (2) a tuple giving num_embedding and dimension, e.g. (1000, 100)
+    :param num_classes: total number of classes
+    :param num_layers: number of BiLSTM layers
+    :param hidden_size: hidden size of the BiLSTM; the effective hidden size is twice this value (forward plus backward)
+    :param dropout: dropout probability; 0 disables dropout
+    :param target_vocab: a Vocabulary object giving the target-to-index mapping
+    :param encoding_type: the encoding scheme; supports 'bioes', 'bmes', 'bio', 'bmeso', etc.
     """
     def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5,
                  target_vocab=None, encoding_type=None):
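With the export enabled, BiLSTMCRF can be constructed directly from fastNLP.models. A hedged sketch; the vocabulary, tag set, and sizes below are invented:

    from fastNLP import Vocabulary
    from fastNLP.models import BiLSTMCRF

    target_vocab = Vocabulary(unknown=None, padding=None)
    target_vocab.add_word_lst(["O", "B-LOC", "I-LOC"])
    # embed given as a (num_embedding, dimension) tuple; any fastNLP Embedding works as well
    model = BiLSTMCRF(embed=(10000, 100), num_classes=len(target_vocab),
                      num_layers=1, hidden_size=100, dropout=0.5,
                      target_vocab=target_vocab, encoding_type='bio')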
@@ -86,21 +82,20 @@ class SeqLabeling(BaseModel):
 
-    A basic sequence labeling model.
+    A base class for sequence labeling. The architecture is one Embedding layer, one LSTM (unidirectional, single layer), one FC layer, and one CRF layer.
 
-    :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: size of the Embedding (pass a tuple(int, int),
-        where the first int is vocab_size and the second is embed_dim); if a Tensor, Embedding, ndarray, etc. is given, it initializes the Embedding directly
+    :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: size of the Embedding (pass a tuple(int, int),
+        where the first int is vocab_size and the second is embed_dim); if a Tensor, embedding, ndarray, etc. is given, it initializes the Embedding directly
     :param int hidden_size: size of the LSTM hidden layer
     :param int num_classes: total number of classes
     """
 
-    def __init__(self, init_embed, hidden_size, num_classes):
+    def __init__(self, embed, hidden_size, num_classes):
         super(SeqLabeling, self).__init__()
 
-        self.Embedding = embedding.Embedding(init_embed)
-        self.Rnn = encoder.LSTM(self.Embedding.embedding_dim, hidden_size)
-        self.Linear = nn.Linear(hidden_size, num_classes)
-        self.Crf = decoder.ConditionalRandomField(num_classes)
-        self.mask = None
+        self.embedding = get_embeddings(embed)
+        self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size)
+        self.fc = nn.Linear(hidden_size, num_classes)
+        self.crf = decoder.ConditionalRandomField(num_classes)
 
     def forward(self, words, seq_len, target):
         """
         :param torch.LongTensor words: [batch_size, max_len], token indices of the sequence
@@ -109,17 +104,14 @@ class SeqLabeling(BaseModel):
         :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
             If truth is not None, return loss, a scalar. Used in training.
         """
-        assert words.shape[0] == seq_len.shape[0]
-        assert target.shape == words.shape
-        self.mask = self._make_mask(words, seq_len)
-
-        x = self.Embedding(words)
+        mask = seq_len_to_mask(seq_len, max_len=words.size(1))
+        x = self.embedding(words)
         # [batch_size, max_len, word_emb_dim]
-        x, _ = self.Rnn(x, seq_len)
+        x, _ = self.rnn(x, seq_len)
         # [batch_size, max_len, hidden_size * direction]
-        x = self.Linear(x)
+        x = self.fc(x)
         # [batch_size, max_len, num_classes]
-        return {C.LOSS: self._internal_loss(x, target)}
+        return {C.LOSS: self._internal_loss(x, target, mask)}
 
     def predict(self, words, seq_len):
         """
@@ -129,18 +121,18 @@ class SeqLabeling(BaseModel):
         :param torch.LongTensor seq_len: [batch_size,]
         :return: {'pred': xx}, [batch_size, max_len]
         """
-        self.mask = self._make_mask(words, seq_len)
+        mask = seq_len_to_mask(seq_len, max_len=words.size(1))
 
-        x = self.Embedding(words)
+        x = self.embedding(words)
         # [batch_size, max_len, word_emb_dim]
-        x, _ = self.Rnn(x, seq_len)
+        x, _ = self.rnn(x, seq_len)
         # [batch_size, max_len, hidden_size * direction]
-        x = self.Linear(x)
+        x = self.fc(x)
         # [batch_size, max_len, num_classes]
-        pred = self._decode(x)
+        pred = self._decode(x, mask)
         return {C.OUTPUT: pred}
 
-    def _internal_loss(self, x, y):
+    def _internal_loss(self, x, y, mask):
         """
         Negative log likelihood loss.
         :param x: Tensor, [batch_size, max_len, tag_size]
@@ -152,22 +144,15 @@ class SeqLabeling(BaseModel):
         y = y.long()
         assert x.shape[:2] == y.shape
-        assert y.shape == self.mask.shape
-        total_loss = self.Crf(x, y, self.mask)
+        total_loss = self.crf(x, y, mask)
         return torch.mean(total_loss)
 
-    def _make_mask(self, x, seq_len):
-        batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_len_to_mask(seq_len)
-        mask = mask.view(batch_size, max_len)
-        mask = mask.to(x).float()
-        return mask
-
-    def _decode(self, x):
+    def _decode(self, x, mask):
         """
         :param torch.FloatTensor x: [batch_size, max_len, tag_size]
         :return prediction: [batch_size, max_len]
         """
-        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
+        tag_seq, _ = self.crf.viterbi_decode(x, mask)
         return tag_seq
 
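The refactor drops the stored self.mask in favor of a mask computed per call. A standalone sketch of what seq_len_to_mask produces (values invented):

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    seq_len = torch.tensor([3, 2])
    mask = seq_len_to_mask(seq_len, max_len=4)
    # mask flags real positions vs. padding, roughly:
    # [[1, 1, 1, 0],
    #  [1, 1, 0, 0]]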
@@ -177,7 +162,7 @@ class AdvSeqLabel(nn.Module):
 
     A more sophisticated sequence labeling model. The architecture is Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
 
-    :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: size of the Embedding (pass a tuple(int, int),
+    :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: size of the Embedding (pass a tuple(int, int),
         where the first int is vocab_size and the second is embed_dim); if a Tensor, Embedding, ndarray, etc. is given, it initializes the Embedding directly
     :param int hidden_size: size of the LSTM hidden layer
     :param int num_classes: number of classes
@@ -188,11 +173,11 @@ class AdvSeqLabel(nn.Module):
     :param str encoding_type: supports "BIO", "BMES", "BEMSO"; only used when id2words is not None.
     """
 
-    def __init__(self, init_embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'):
+    def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'):
 
         super().__init__()
 
-        self.Embedding = embedding.Embedding(init_embed)
+        self.Embedding = get_embeddings(embed)
         self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
         self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2,
                                 dropout=dropout,
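Across all three models the init_embed parameter becomes embed, resolved through get_embeddings, which accepts either a (vocab_size, embed_dim) tuple or a prebuilt embedding object. A minimal sketch of the tuple form:

    from fastNLP.embeddings import get_embeddings

    emb = get_embeddings((5000, 128))  # tuple form builds a randomly initialized embedding
    print(emb.embedding_dim)           # 128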
@@ -18,11 +18,9 @@ from fastNLP.io.pipe.conll import OntoNotesNERPipe
 
 #######hyper
-normalize = False
-lower = False
 lr = 0.01
 dropout = 0.5
 batch_size = 32
 job_embed = False
 data_name = 'ontonote'
 #######hyper
@@ -41,7 +39,7 @@ def cache():
                                  word_dropout=0.01,
                                  dropout=dropout,
                                  lower=True,
-                                 min_freq=2)
+                                 min_freq=1)
     return data, char_embed, word_embed
 data, char_embed, word_embed = cache()
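The keyword arguments in this hunk match fastNLP's StaticEmbedding, so the surrounding call in cache() presumably looks roughly like the sketch below (the vocabulary and embedding_dim are assumptions); with min_freq=1 the embedding vocabulary now keeps words that occur only once instead of treating them as unknown:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()  # assume this was built from the training data beforehand
    word_embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=100,
                                 word_dropout=0.01, dropout=0.5, lower=True, min_freq=1)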