mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
commit
f2850766b8
@ -1 +0,0 @@
|
||||
|
@ -4,88 +4,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class Action(object):
|
||||
"""Operations shared by Trainer, Tester, or Inference.
|
||||
|
||||
This is designed for reducing replicate codes.
|
||||
- make_batch: produce a min-batch of data. @staticmethod
|
||||
- pad: padding method used in sequence modeling. @staticmethod
|
||||
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Action, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def make_batch(iterator, use_cuda, output_length=True, max_len=None):
|
||||
"""Batch and Pad data.
|
||||
|
||||
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
|
||||
:param use_cuda: bool, whether to use GPU
|
||||
:param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
|
||||
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
|
||||
:return :
|
||||
|
||||
if output_length is True,
|
||||
(batch_x, seq_len): tuple of two elements
|
||||
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
|
||||
seq_len: list. The length of the pre-padded sequence, if output_length is True.
|
||||
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
|
||||
|
||||
if output_length is False,
|
||||
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
|
||||
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
|
||||
"""
|
||||
for batch in iterator:
|
||||
batch_x = [sample[0] for sample in batch]
|
||||
batch_y = [sample[1] for sample in batch]
|
||||
|
||||
batch_x = Action.pad(batch_x)
|
||||
# pad batch_y only if it is a 2-level list
|
||||
if len(batch_y) > 0 and isinstance(batch_y[0], list):
|
||||
batch_y = Action.pad(batch_y)
|
||||
|
||||
# convert list to tensor
|
||||
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
|
||||
batch_y = convert_to_torch_tensor(batch_y, use_cuda)
|
||||
|
||||
# trim data to max_len
|
||||
if max_len is not None and batch_x.size(1) > max_len:
|
||||
batch_x = batch_x[:, :max_len]
|
||||
|
||||
if output_length:
|
||||
seq_len = [len(x) for x in batch_x]
|
||||
yield (batch_x, seq_len), batch_y
|
||||
else:
|
||||
yield batch_x, batch_y
|
||||
|
||||
@staticmethod
|
||||
def pad(batch, fill=0):
|
||||
""" Pad a mini-batch of sequence samples to maximum length of this batch.
|
||||
|
||||
:param batch: list of list
|
||||
:param fill: word index to pad, default 0.
|
||||
:return batch: a padded mini-batch
|
||||
"""
|
||||
max_length = max([len(x) for x in batch])
|
||||
for idx, sample in enumerate(batch):
|
||||
if len(sample) < max_length:
|
||||
batch[idx] = sample + ([fill] * (max_length - len(sample)))
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def mode(model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
|
||||
:param model: a PyTorch model
|
||||
:param is_test: bool, whether in test mode or not.
|
||||
"""
|
||||
if is_test:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
|
||||
def convert_to_torch_tensor(data_list, use_cuda):
|
||||
"""Convert lists into (cuda) Tensors.
|
||||
|
||||
@ -168,19 +86,7 @@ class BaseSampler(object):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data_set):
|
||||
"""
|
||||
|
||||
:param data_set: multi-level list, of shape [num_example, *]
|
||||
|
||||
"""
|
||||
self.data_set_length = len(data_set)
|
||||
self.data = data_set
|
||||
|
||||
def __len__(self):
|
||||
return self.data_set_length
|
||||
|
||||
def __iter__(self):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -189,16 +95,8 @@ class SequentialSampler(BaseSampler):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data_set):
|
||||
"""
|
||||
|
||||
:param data_set: multi-level list
|
||||
|
||||
"""
|
||||
super(SequentialSampler, self).__init__(data_set)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
def __call__(self, data_set):
|
||||
return list(range(len(data_set)))
|
||||
|
||||
|
||||
class RandomSampler(BaseSampler):
|
||||
@ -206,17 +104,9 @@ class RandomSampler(BaseSampler):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data_set):
|
||||
"""
|
||||
def __call__(self, data_set):
|
||||
return list(np.random.permutation(len(data_set)))
|
||||
|
||||
:param data_set: multi-level list
|
||||
|
||||
"""
|
||||
super(RandomSampler, self).__init__(data_set)
|
||||
self.order = np.random.permutation(self.data_set_length)
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.data[idx] for idx in self.order))
|
||||
|
||||
|
||||
class Batchifier(object):
|
||||
@ -252,6 +142,7 @@ class BucketBatchifier(Batchifier):
|
||||
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
|
||||
In sampling, first random choose a bucket. Then sample data from it.
|
||||
The number of buckets is decided dynamically by the variance of sentence lengths.
|
||||
TODO: merge it into Batch
|
||||
"""
|
||||
|
||||
def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
|
||||
|
126
fastNLP/core/batch.py
Normal file
126
fastNLP/core/batch.py
Normal file
@ -0,0 +1,126 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.field import TextField, LabelField
|
||||
from fastNLP.core.instance import Instance
|
||||
|
||||
|
||||
class Batch(object):
|
||||
"""Batch is an iterable object which iterates over mini-batches.
|
||||
|
||||
::
|
||||
for batch_x, batch_y in Batch(data_set):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, batch_size, sampler, use_cuda):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.sampler = sampler
|
||||
self.use_cuda = use_cuda
|
||||
self.idx_list = None
|
||||
self.curidx = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.idx_list = self.sampler(self.dataset)
|
||||
self.curidx = 0
|
||||
self.lengths = self.dataset.get_length()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""
|
||||
|
||||
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
|
||||
batch_x also contains an item (str: list of int) about origin lengths,
|
||||
which means ("field_name_origin_len": origin lengths).
|
||||
E.g.
|
||||
::
|
||||
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})
|
||||
|
||||
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
|
||||
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
|
||||
The names of fields are defined in preprocessor's convert_to_dataset method.
|
||||
|
||||
"""
|
||||
if self.curidx >= len(self.idx_list):
|
||||
raise StopIteration
|
||||
else:
|
||||
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
|
||||
padding_length = {field_name: max(field_length[self.curidx: endidx])
|
||||
for field_name, field_length in self.lengths.items()}
|
||||
origin_lengths = {field_name: field_length[self.curidx: endidx]
|
||||
for field_name, field_length in self.lengths.items()}
|
||||
|
||||
batch_x, batch_y = defaultdict(list), defaultdict(list)
|
||||
for idx in range(self.curidx, endidx):
|
||||
x, y = self.dataset.to_tensor(idx, padding_length)
|
||||
for name, tensor in x.items():
|
||||
batch_x[name].append(tensor)
|
||||
for name, tensor in y.items():
|
||||
batch_y[name].append(tensor)
|
||||
|
||||
batch_origin_length = {}
|
||||
# combine instances into a batch
|
||||
for batch in (batch_x, batch_y):
|
||||
for name, tensor_list in batch.items():
|
||||
if self.use_cuda:
|
||||
batch[name] = torch.stack(tensor_list, dim=0).cuda()
|
||||
else:
|
||||
batch[name] = torch.stack(tensor_list, dim=0)
|
||||
|
||||
# add origin lengths in batch_x
|
||||
for name, tensor in batch_x.items():
|
||||
if self.use_cuda:
|
||||
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
|
||||
else:
|
||||
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
|
||||
batch_x.update(batch_origin_length)
|
||||
|
||||
self.curidx += endidx
|
||||
return batch_x, batch_y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""simple running example
|
||||
"""
|
||||
texts = ["i am a cat",
|
||||
"this is a test of new batch",
|
||||
"haha"
|
||||
]
|
||||
labels = [0, 1, 0]
|
||||
|
||||
# prepare vocabulary
|
||||
vocab = {}
|
||||
for text in texts:
|
||||
for tokens in text.split():
|
||||
if tokens not in vocab:
|
||||
vocab[tokens] = len(vocab)
|
||||
print("vocabulary: ", vocab)
|
||||
|
||||
# prepare input dataset
|
||||
data = DataSet()
|
||||
for text, label in zip(texts, labels):
|
||||
x = TextField(text.split(), False)
|
||||
y = LabelField(label, is_target=True)
|
||||
ins = Instance(text=x, label=y)
|
||||
data.append(ins)
|
||||
|
||||
# use vocabulary to index data
|
||||
data.index_field("text", vocab)
|
||||
|
||||
|
||||
# define naive sampler for batch class
|
||||
class SeqSampler:
|
||||
def __call__(self, dataset):
|
||||
return list(range(len(dataset)))
|
||||
|
||||
|
||||
# use batch to iterate dataset
|
||||
data_iterator = Batch(data, 2, SeqSampler(), False)
|
||||
for epoch in range(1):
|
||||
for batch_x, batch_y in data_iterator:
|
||||
print(batch_x)
|
||||
print(batch_y)
|
||||
# do stuff
|
111
fastNLP/core/dataset.py
Normal file
111
fastNLP/core/dataset.py
Normal file
@ -0,0 +1,111 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from fastNLP.core.field import TextField
|
||||
from fastNLP.core.instance import Instance
|
||||
|
||||
|
||||
def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
|
||||
if has_target is True:
|
||||
if label_vocab is None:
|
||||
raise RuntimeError("Must provide label vocabulary to transform labels.")
|
||||
return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
|
||||
else:
|
||||
return create_unlabeled_dataset_from_lists(str_lists, word_vocab)
|
||||
|
||||
|
||||
def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
|
||||
"""Create an DataSet instance that contains labels.
|
||||
|
||||
:param str_lists: list of list of strings, [num_examples, 2, *].
|
||||
::
|
||||
[
|
||||
[[word_11, word_12, ...], [label_11, label_12, ...]],
|
||||
...
|
||||
]
|
||||
|
||||
:param word_vocab: dict of (str: int), which means (word: index).
|
||||
:param label_vocab: dict of (str: int), which means (word: index).
|
||||
:return data_set: a DataSet instance.
|
||||
|
||||
"""
|
||||
data_set = DataSet()
|
||||
for example in str_lists:
|
||||
word_seq, label_seq = example[0], example[1]
|
||||
x = TextField(word_seq, is_target=False)
|
||||
y = TextField(label_seq, is_target=True)
|
||||
data_set.append(Instance(word_seq=x, label_seq=y))
|
||||
data_set.index_field("word_seq", word_vocab)
|
||||
data_set.index_field("label_seq", label_vocab)
|
||||
return data_set
|
||||
|
||||
|
||||
def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
|
||||
"""Create an DataSet instance that contains no labels.
|
||||
|
||||
:param str_lists: list of list of strings, [num_examples, *].
|
||||
::
|
||||
[
|
||||
[word_11, word_12, ...],
|
||||
...
|
||||
]
|
||||
|
||||
:param word_vocab: dict of (str: int), which means (word: index).
|
||||
:return data_set: a DataSet instance.
|
||||
|
||||
"""
|
||||
data_set = DataSet()
|
||||
for word_seq in str_lists:
|
||||
x = TextField(word_seq, is_target=False)
|
||||
data_set.append(Instance(word_seq=x))
|
||||
data_set.index_field("word_seq", word_vocab)
|
||||
return data_set
|
||||
|
||||
|
||||
class DataSet(list):
|
||||
"""A DataSet object is a list of Instance objects.
|
||||
|
||||
"""
|
||||
def __init__(self, name="", instances=None):
|
||||
"""
|
||||
|
||||
:param name: str, the name of the dataset. (default: "")
|
||||
:param instances: list of Instance objects. (default: None)
|
||||
|
||||
"""
|
||||
list.__init__([])
|
||||
self.name = name
|
||||
if instances is not None:
|
||||
self.extend(instances)
|
||||
|
||||
def index_all(self, vocab):
|
||||
for ins in self:
|
||||
ins.index_all(vocab)
|
||||
|
||||
def index_field(self, field_name, vocab):
|
||||
for ins in self:
|
||||
ins.index_field(field_name, vocab)
|
||||
|
||||
def to_tensor(self, idx: int, padding_length: dict):
|
||||
"""Convert an instance in a dataset to tensor.
|
||||
|
||||
:param idx: int, the index of the instance in the dataset.
|
||||
:param padding_length: int
|
||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
|
||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
|
||||
|
||||
"""
|
||||
ins = self[idx]
|
||||
return ins.to_tensor(padding_length)
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch lengths of all fields in all instances in a dataset.
|
||||
|
||||
:return lengths: dict of (str: list). The str is the field name.
|
||||
The list contains lengths of this field in all instances.
|
||||
|
||||
"""
|
||||
lengths = defaultdict(list)
|
||||
for ins in self:
|
||||
for field_name, field_length in ins.get_length().items():
|
||||
lengths[field_name].append(field_length)
|
||||
return lengths
|
93
fastNLP/core/field.py
Normal file
93
fastNLP/core/field.py
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Field(object):
|
||||
"""A field defines a data type.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, is_target: bool):
|
||||
self.is_target = is_target
|
||||
|
||||
def index(self, vocab):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_length(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def to_tensor(self, padding_length):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TextField(Field):
|
||||
def __init__(self, text, is_target):
|
||||
"""
|
||||
:param text: list of strings
|
||||
:param is_target: bool
|
||||
"""
|
||||
super(TextField, self).__init__(is_target)
|
||||
self.text = text
|
||||
self._index = None
|
||||
|
||||
def index(self, vocab):
|
||||
if self._index is None:
|
||||
self._index = [vocab[c] for c in self.text]
|
||||
else:
|
||||
raise RuntimeError("Replicate indexing of this field.")
|
||||
return self._index
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch the length of the text field.
|
||||
|
||||
:return length: int, the length of the text.
|
||||
|
||||
"""
|
||||
return len(self.text)
|
||||
|
||||
def to_tensor(self, padding_length: int):
|
||||
"""Convert text field to tensor.
|
||||
|
||||
:param padding_length: int
|
||||
:return tensor: torch.LongTensor, of shape [padding_length, ]
|
||||
"""
|
||||
pads = []
|
||||
if self._index is None:
|
||||
raise RuntimeError("Indexing not done before to_tensor in TextField.")
|
||||
if padding_length > self.get_length():
|
||||
pads = [0] * (padding_length - self.get_length())
|
||||
return torch.LongTensor(self._index + pads)
|
||||
|
||||
|
||||
class LabelField(Field):
|
||||
def __init__(self, label, is_target=True):
|
||||
super(LabelField, self).__init__(is_target)
|
||||
self.label = label
|
||||
self._index = None
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch the length of the label field.
|
||||
|
||||
:return length: int, the length of the label, always 1.
|
||||
"""
|
||||
return 1
|
||||
|
||||
def index(self, vocab):
|
||||
if self._index is None:
|
||||
self._index = vocab[self.label]
|
||||
return self._index
|
||||
|
||||
def to_tensor(self, padding_length):
|
||||
if self._index is None:
|
||||
if isinstance(self.label, int):
|
||||
return torch.LongTensor([self.label])
|
||||
elif isinstance(self.label, str):
|
||||
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
|
||||
else:
|
||||
return torch.LongTensor([self._index])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf = TextField("test the code".split(), is_target=False)
|
53
fastNLP/core/instance.py
Normal file
53
fastNLP/core/instance.py
Normal file
@ -0,0 +1,53 @@
|
||||
class Instance(object):
|
||||
"""An instance which consists of Fields is an example in the DataSet.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **fields):
|
||||
self.fields = fields
|
||||
self.has_index = False
|
||||
self.indexes = {}
|
||||
|
||||
def add_field(self, field_name, field):
|
||||
self.fields[field_name] = field
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch the length of all fields in the instance.
|
||||
|
||||
:return length: dict of (str: int), which means (field name: field length).
|
||||
|
||||
"""
|
||||
length = {name: field.get_length() for name, field in self.fields.items()}
|
||||
return length
|
||||
|
||||
def index_field(self, field_name, vocab):
|
||||
"""use `vocab` to index certain field
|
||||
"""
|
||||
self.indexes[field_name] = self.fields[field_name].index(vocab)
|
||||
|
||||
def index_all(self, vocab):
|
||||
"""use `vocab` to index all fields
|
||||
"""
|
||||
if self.has_index:
|
||||
print("error")
|
||||
return self.indexes
|
||||
indexes = {name: field.index(vocab) for name, field in self.fields.items()}
|
||||
self.indexes = indexes
|
||||
return indexes
|
||||
|
||||
def to_tensor(self, padding_length: dict):
|
||||
"""Convert instance to tensor.
|
||||
|
||||
:param padding_length: dict of (str: int), which means (field name: padding_length of this field)
|
||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
|
||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
|
||||
If is_target is False for all fields, tensor_y would be an empty dict.
|
||||
"""
|
||||
tensor_x = {}
|
||||
tensor_y = {}
|
||||
for name, field in self.fields.items():
|
||||
if field.is_target:
|
||||
tensor_y[name] = field.to_tensor(padding_length[name])
|
||||
else:
|
||||
tensor_x[name] = field.to_tensor(padding_length[name])
|
||||
return tensor_x, tensor_y
|
@ -37,5 +37,7 @@ class Loss(object):
|
||||
"""
|
||||
if loss_name == "cross_entropy":
|
||||
return torch.nn.CrossEntropyLoss()
|
||||
elif loss_name == 'nll':
|
||||
return torch.nn.NLLLoss()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -1,53 +1,10 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fastNLP.core.action import Batchifier, SequentialSampler
|
||||
from fastNLP.core.action import convert_to_torch_tensor
|
||||
from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
|
||||
from fastNLP.modules import utils
|
||||
|
||||
|
||||
def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
|
||||
"""Batch and Pad data, only for Inference.
|
||||
|
||||
:param iterator: An iterable object that returns a list of indices representing a mini-batch of samples.
|
||||
:param use_cuda: bool, whether to use GPU
|
||||
:param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
|
||||
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
|
||||
:param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
|
||||
:return:
|
||||
"""
|
||||
for batch_x in iterator:
|
||||
batch_x = pad(batch_x)
|
||||
# convert list to tensor
|
||||
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
|
||||
|
||||
# trim data to max_len
|
||||
if max_len is not None and batch_x.size(1) > max_len:
|
||||
batch_x = batch_x[:, :max_len]
|
||||
if min_len is not None and batch_x.size(1) < min_len:
|
||||
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
|
||||
batch_x = torch.cat((batch_x, pad_tensor), 1)
|
||||
|
||||
if output_length:
|
||||
seq_len = [len(x) for x in batch_x]
|
||||
yield tuple([batch_x, seq_len])
|
||||
else:
|
||||
yield batch_x
|
||||
|
||||
|
||||
def pad(batch, fill=0):
|
||||
""" Pad a mini-batch of sequence samples to maximum length of this batch.
|
||||
|
||||
:param batch: list of list
|
||||
:param fill: word index to pad, default 0.
|
||||
:return batch: a padded mini-batch
|
||||
"""
|
||||
max_length = max([len(x) for x in batch])
|
||||
for idx, sample in enumerate(batch):
|
||||
if len(sample) < max_length:
|
||||
batch[idx] = sample + ([fill] * (max_length - len(sample)))
|
||||
return batch
|
||||
from fastNLP.core.action import SequentialSampler
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.dataset import create_dataset_from_lists
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
|
||||
|
||||
class Predictor(object):
|
||||
@ -59,11 +16,17 @@ class Predictor(object):
|
||||
Currently, Predictor does not support GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, pickle_path):
|
||||
def __init__(self, pickle_path, task):
|
||||
"""
|
||||
|
||||
:param pickle_path: str, the path to the pickle files.
|
||||
:param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
|
||||
|
||||
"""
|
||||
self.batch_size = 1
|
||||
self.batch_output = []
|
||||
self.iterator = None
|
||||
self.pickle_path = pickle_path
|
||||
self._task = task # one of ("seq_label", "text_classify")
|
||||
self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
|
||||
self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
|
||||
|
||||
@ -71,19 +34,19 @@ class Predictor(object):
|
||||
"""Perform inference using the trained model.
|
||||
|
||||
:param network: a PyTorch model (cpu)
|
||||
:param data: list of list of strings
|
||||
:param data: list of list of strings, [num_examples, seq_len]
|
||||
:return: list of list of strings, [num_examples, tag_seq_length]
|
||||
"""
|
||||
# transform strings into indices
|
||||
# transform strings into DataSet object
|
||||
data = self.prepare_input(data)
|
||||
|
||||
# turn on the testing mode; clean up the history
|
||||
self.mode(network, test=True)
|
||||
self.batch_output.clear()
|
||||
|
||||
data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
|
||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
|
||||
|
||||
for batch_x in self.make_batch(data_iterator, use_cuda=False):
|
||||
for batch_x, _ in data_iterator:
|
||||
with torch.no_grad():
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
@ -99,103 +62,61 @@ class Predictor(object):
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""Forward through network."""
|
||||
raise NotImplementedError
|
||||
|
||||
def make_batch(self, iterator, use_cuda):
|
||||
raise NotImplementedError
|
||||
y = network(**x)
|
||||
if self._task == "seq_label":
|
||||
y = network.prediction(y)
|
||||
return y
|
||||
|
||||
def prepare_input(self, data):
|
||||
"""Transform two-level list of strings into that of index.
|
||||
"""Transform two-level list of strings into an DataSet object.
|
||||
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.
|
||||
|
||||
:param data:
|
||||
:param data: list of list of strings.
|
||||
::
|
||||
[
|
||||
[word_11, word_12, ...],
|
||||
[word_21, word_22, ...],
|
||||
...
|
||||
]
|
||||
:return data_index: list of list of int.
|
||||
|
||||
:return data_set: a DataSet instance.
|
||||
"""
|
||||
assert isinstance(data, list)
|
||||
data_index = []
|
||||
default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
|
||||
for example in data:
|
||||
data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
|
||||
return data_index
|
||||
return create_dataset_from_lists(data, self.word2index, has_target=False)
|
||||
|
||||
def prepare_output(self, data):
|
||||
"""Transform list of batch outputs into strings."""
|
||||
raise NotImplementedError
|
||||
if self._task == "seq_label":
|
||||
return self._seq_label_prepare_output(data)
|
||||
elif self._task == "text_classify":
|
||||
return self._text_classify_prepare_output(data)
|
||||
else:
|
||||
raise NotImplementedError("Unknown task type {}".format(self._task))
|
||||
|
||||
|
||||
class SeqLabelInfer(Predictor):
|
||||
"""
|
||||
Inference on sequence labeling models.
|
||||
"""
|
||||
|
||||
def __init__(self, pickle_path):
|
||||
super(SeqLabelInfer, self).__init__(pickle_path)
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
"""
|
||||
This is only for sequence labeling with CRF decoder.
|
||||
:param network: a PyTorch model
|
||||
:param inputs: tuple of (x, seq_len)
|
||||
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
|
||||
after padding.
|
||||
seq_len: list of int, the lengths of sequences before padding.
|
||||
:return prediction: Tensor of shape [batch_size, max_len]
|
||||
"""
|
||||
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
|
||||
raise RuntimeError("output_length must be true for sequence modeling.")
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
y = network(x)
|
||||
prediction = network.prediction(y, mask)
|
||||
return torch.Tensor(prediction)
|
||||
|
||||
def make_batch(self, iterator, use_cuda):
|
||||
return make_batch(iterator, use_cuda, output_length=True)
|
||||
|
||||
def prepare_output(self, batch_outputs):
|
||||
"""Transform list of batch outputs into strings.
|
||||
|
||||
:param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length].
|
||||
:return results: 2-D list of strings, shape [num_examples, tag_seq_length]
|
||||
"""
|
||||
def _seq_label_prepare_output(self, batch_outputs):
|
||||
results = []
|
||||
for batch in batch_outputs:
|
||||
for example in np.array(batch):
|
||||
results.append([self.index2label[int(x)] for x in example])
|
||||
return results
|
||||
|
||||
|
||||
class ClassificationInfer(Predictor):
|
||||
"""
|
||||
Inference on Classification models.
|
||||
"""
|
||||
|
||||
def __init__(self, pickle_path):
|
||||
super(ClassificationInfer, self).__init__(pickle_path)
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""Forward through network."""
|
||||
logits = network(x)
|
||||
return logits
|
||||
|
||||
def make_batch(self, iterator, use_cuda):
|
||||
return make_batch(iterator, use_cuda, output_length=False, min_len=5)
|
||||
|
||||
def prepare_output(self, batch_outputs):
|
||||
"""
|
||||
Transform list of batch outputs into strings.
|
||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
|
||||
:return results: list of strings
|
||||
"""
|
||||
def _text_classify_prepare_output(self, batch_outputs):
|
||||
results = []
|
||||
for batch_out in batch_outputs:
|
||||
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
|
||||
results.extend([self.index2label[i] for i in idx])
|
||||
return results
|
||||
|
||||
|
||||
class SeqLabelInfer(Predictor):
|
||||
def __init__(self, pickle_path):
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
|
||||
super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
|
||||
|
||||
|
||||
class ClassificationInfer(Predictor):
|
||||
def __init__(self, pickle_path):
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
|
||||
super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
|
||||
|
@ -3,6 +3,10 @@ import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.field import TextField, LabelField
|
||||
from fastNLP.core.instance import Instance
|
||||
|
||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
|
||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
|
||||
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
|
||||
@ -84,7 +88,7 @@ class BasePreprocess(object):
|
||||
return len(self.label2index)
|
||||
|
||||
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
|
||||
"""Main preprocessing pipeline.
|
||||
"""Main pre-processing pipeline.
|
||||
|
||||
:param train_dev_data: three-level list, with either single label or multiple labels in a sample.
|
||||
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
|
||||
@ -92,7 +96,9 @@ class BasePreprocess(object):
|
||||
:param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set.
|
||||
:param cross_val: bool, whether to do cross validation.
|
||||
:param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
|
||||
:return results: a tuple of datasets after preprocessing.
|
||||
:return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
|
||||
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
|
||||
is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
|
||||
"""
|
||||
|
||||
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
|
||||
@ -111,68 +117,87 @@ class BasePreprocess(object):
|
||||
index2label = self.build_reverse_dict(self.label2index)
|
||||
save_pickle(index2label, pickle_path, "id2class.pkl")
|
||||
|
||||
data_train = []
|
||||
data_dev = []
|
||||
train_set = []
|
||||
dev_set = []
|
||||
if not cross_val:
|
||||
if not pickle_exist(pickle_path, "data_train.pkl"):
|
||||
data_train.extend(self.to_index(train_dev_data))
|
||||
if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
|
||||
split = int(len(data_train) * train_dev_split)
|
||||
data_dev = data_train[: split]
|
||||
data_train = data_train[split:]
|
||||
save_pickle(data_dev, pickle_path, "data_dev.pkl")
|
||||
split = int(len(train_dev_data) * train_dev_split)
|
||||
data_dev = train_dev_data[: split]
|
||||
data_train = train_dev_data[split:]
|
||||
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
|
||||
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
|
||||
|
||||
save_pickle(dev_set, pickle_path, "data_dev.pkl")
|
||||
print("{} of the training data is split for validation. ".format(train_dev_split))
|
||||
save_pickle(data_train, pickle_path, "data_train.pkl")
|
||||
else:
|
||||
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
|
||||
save_pickle(train_set, pickle_path, "data_train.pkl")
|
||||
else:
|
||||
data_train = load_pickle(pickle_path, "data_train.pkl")
|
||||
train_set = load_pickle(pickle_path, "data_train.pkl")
|
||||
if pickle_exist(pickle_path, "data_dev.pkl"):
|
||||
data_dev = load_pickle(pickle_path, "data_dev.pkl")
|
||||
dev_set = load_pickle(pickle_path, "data_dev.pkl")
|
||||
else:
|
||||
# cross_val is True
|
||||
if not pickle_exist(pickle_path, "data_train_0.pkl"):
|
||||
# cross validation
|
||||
data_idx = self.to_index(train_dev_data)
|
||||
data_cv = self.cv_split(data_idx, n_fold)
|
||||
data_cv = self.cv_split(train_dev_data, n_fold)
|
||||
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
|
||||
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
|
||||
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
|
||||
save_pickle(
|
||||
data_train_cv, pickle_path,
|
||||
"data_train_{}.pkl".format(i))
|
||||
save_pickle(
|
||||
data_dev_cv, pickle_path,
|
||||
"data_dev_{}.pkl".format(i))
|
||||
data_train.append(data_train_cv)
|
||||
data_dev.append(data_dev_cv)
|
||||
train_set.append(data_train_cv)
|
||||
dev_set.append(data_dev_cv)
|
||||
print("{}-fold cross validation.".format(n_fold))
|
||||
else:
|
||||
for i in range(n_fold):
|
||||
data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i))
|
||||
data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i))
|
||||
data_train.append(data_train_cv)
|
||||
data_dev.append(data_dev_cv)
|
||||
train_set.append(data_train_cv)
|
||||
dev_set.append(data_dev_cv)
|
||||
|
||||
# prepare test data if provided
|
||||
data_test = []
|
||||
test_set = []
|
||||
if test_data is not None:
|
||||
if not pickle_exist(pickle_path, "data_test.pkl"):
|
||||
data_test = self.to_index(test_data)
|
||||
save_pickle(data_test, pickle_path, "data_test.pkl")
|
||||
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
|
||||
save_pickle(test_set, pickle_path, "data_test.pkl")
|
||||
|
||||
# return preprocessed results
|
||||
results = [data_train]
|
||||
results = [train_set]
|
||||
if cross_val or train_dev_split > 0:
|
||||
results.append(data_dev)
|
||||
results.append(dev_set)
|
||||
if test_data:
|
||||
results.append(data_test)
|
||||
results.append(test_set)
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
else:
|
||||
return tuple(results)
|
||||
|
||||
def build_dict(self, data):
|
||||
raise NotImplementedError
|
||||
label2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
word2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
for example in data:
|
||||
for word in example[0]:
|
||||
if word not in word2index:
|
||||
word2index[word] = len(word2index)
|
||||
label = example[1]
|
||||
if isinstance(label, str):
|
||||
# label is a string
|
||||
if label not in label2index:
|
||||
label2index[label] = len(label2index)
|
||||
elif isinstance(label, list):
|
||||
# label is a list of strings
|
||||
for single_label in label:
|
||||
if single_label not in label2index:
|
||||
label2index[single_label] = len(label2index)
|
||||
return word2index, label2index
|
||||
|
||||
def to_index(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reverse_dict(self, word_dict):
|
||||
id2word = {word_dict[w]: w for w in word_dict}
|
||||
@ -186,11 +211,23 @@ class BasePreprocess(object):
|
||||
return data_train, data_dev
|
||||
|
||||
def cv_split(self, data, n_fold):
|
||||
"""Split data for cross validation."""
|
||||
"""Split data for cross validation.
|
||||
|
||||
:param data: list of string
|
||||
:param n_fold: int
|
||||
:return data_cv:
|
||||
|
||||
::
|
||||
[
|
||||
(data_train, data_dev), # 1st fold
|
||||
(data_train, data_dev), # 2nd fold
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
data_copy = data.copy()
|
||||
np.random.shuffle(data_copy)
|
||||
fold_size = round(len(data_copy) / n_fold)
|
||||
|
||||
data_cv = []
|
||||
for i in range(n_fold - 1):
|
||||
start = i * fold_size
|
||||
@ -202,154 +239,72 @@ class BasePreprocess(object):
|
||||
data_dev = data_copy[start:]
|
||||
data_train = data_copy[:start]
|
||||
data_cv.append((data_train, data_dev))
|
||||
|
||||
return data_cv
|
||||
|
||||
def convert_to_dataset(self, data, vocab, label_vocab):
|
||||
"""Convert list of indices into a DataSet object.
|
||||
|
||||
:param data: list. Entries are strings.
|
||||
:param vocab: a dict, mapping string (token) to index (int).
|
||||
:param label_vocab: a dict, mapping string (label) to index (int).
|
||||
:return data_set: a DataSet object
|
||||
"""
|
||||
use_word_seq = False
|
||||
use_label_seq = False
|
||||
use_label_str = False
|
||||
|
||||
# construct a DataSet object and fill it with Instances
|
||||
data_set = DataSet()
|
||||
for example in data:
|
||||
words, label = example[0], example[1]
|
||||
instance = Instance()
|
||||
|
||||
if isinstance(words, list):
|
||||
x = TextField(words, is_target=False)
|
||||
instance.add_field("word_seq", x)
|
||||
use_word_seq = True
|
||||
else:
|
||||
raise NotImplementedError("words is a {}".format(type(words)))
|
||||
|
||||
if isinstance(label, list):
|
||||
y = TextField(label, is_target=True)
|
||||
instance.add_field("label_seq", y)
|
||||
use_label_seq = True
|
||||
elif isinstance(label, str):
|
||||
y = LabelField(label, is_target=True)
|
||||
instance.add_field("label", y)
|
||||
use_label_str = True
|
||||
else:
|
||||
raise NotImplementedError("label is a {}".format(type(label)))
|
||||
data_set.append(instance)
|
||||
|
||||
# convert strings to indices
|
||||
if use_word_seq:
|
||||
data_set.index_field("word_seq", vocab)
|
||||
if use_label_seq:
|
||||
data_set.index_field("label_seq", label_vocab)
|
||||
if use_label_str:
|
||||
data_set.index_field("label", label_vocab)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
class SeqLabelPreprocess(BasePreprocess):
|
||||
"""Preprocess pipeline, including building mapping from words to index, from index to words,
|
||||
from labels/classes to index, from index to labels/classes.
|
||||
data of three-level list which have multiple labels in each sample.
|
||||
::
|
||||
|
||||
[
|
||||
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
|
||||
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
super(SeqLabelPreprocess, self).__init__()
|
||||
|
||||
def build_dict(self, data):
|
||||
"""Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
|
||||
|
||||
:param data: three-level list
|
||||
::
|
||||
|
||||
[
|
||||
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
|
||||
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
|
||||
...
|
||||
]
|
||||
|
||||
:return word2index: dict of {str, int}
|
||||
label2index: dict of {str, int}
|
||||
"""
|
||||
# In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch.
|
||||
label2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
word2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
for example in data:
|
||||
for word, label in zip(example[0], example[1]):
|
||||
if word not in word2index:
|
||||
word2index[word] = len(word2index)
|
||||
if label not in label2index:
|
||||
label2index[label] = len(label2index)
|
||||
return word2index, label2index
|
||||
|
||||
def to_index(self, data):
|
||||
"""Convert word strings and label strings into indices.
|
||||
|
||||
:param data: three-level list
|
||||
::
|
||||
|
||||
[
|
||||
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
|
||||
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
|
||||
...
|
||||
]
|
||||
|
||||
:return data_index: the same shape as data, but each string is replaced by its corresponding index
|
||||
"""
|
||||
data_index = []
|
||||
for example in data:
|
||||
word_list = []
|
||||
label_list = []
|
||||
for word, label in zip(example[0], example[1]):
|
||||
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
|
||||
label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
|
||||
data_index.append([word_list, label_list])
|
||||
return data_index
|
||||
|
||||
|
||||
class ClassPreprocess(BasePreprocess):
|
||||
""" Preprocess pipeline for classification datasets.
|
||||
Preprocess pipeline, including building mapping from words to index, from index to words,
|
||||
from labels/classes to index, from index to labels/classes.
|
||||
design for data of three-level list which has a single label in each sample.
|
||||
::
|
||||
|
||||
[
|
||||
[ [word_11, word_12, ...], label_1 ],
|
||||
[ [word_21, word_22, ...], label_2 ],
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ClassPreprocess, self).__init__()
|
||||
|
||||
def build_dict(self, data):
|
||||
"""Build vocabulary."""
|
||||
|
||||
# build vocabulary from scratch if nothing exists
|
||||
word2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
label2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
|
||||
# collect every word and label
|
||||
for sent, label in data:
|
||||
if len(sent) <= 1:
|
||||
continue
|
||||
|
||||
if label not in label2index:
|
||||
label2index[label] = len(label2index)
|
||||
|
||||
for word in sent:
|
||||
if word not in word2index:
|
||||
word2index[word] = len(word2index)
|
||||
return word2index, label2index
|
||||
|
||||
def to_index(self, data):
|
||||
"""Convert word strings and label strings into indices.
|
||||
|
||||
:param data: three-level list
|
||||
::
|
||||
|
||||
[
|
||||
[ [word_11, word_12, ...], label_1 ],
|
||||
[ [word_21, word_22, ...], label_2 ],
|
||||
...
|
||||
]
|
||||
|
||||
:return data_index: the same shape as data, but each string is replaced by its corresponding index
|
||||
"""
|
||||
data_index = []
|
||||
for example in data:
|
||||
word_list = []
|
||||
# example[0] is the word list, example[1] is the single label
|
||||
for word in example[0]:
|
||||
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
|
||||
label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
|
||||
data_index.append([word_list, label_index])
|
||||
return data_index
|
||||
|
||||
|
||||
def infer_preprocess(pickle_path, data):
|
||||
"""Preprocess over inference data. Transform three-level list of strings into that of index.
|
||||
::
|
||||
|
||||
[
|
||||
[word_11, word_12, ...],
|
||||
[word_21, word_22, ...],
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
word2index = load_pickle(pickle_path, "word2id.pkl")
|
||||
data_index = []
|
||||
for example in data:
|
||||
data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example])
|
||||
return data_index
|
||||
if __name__ == "__main__":
|
||||
p = BasePreprocess()
|
||||
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
|
||||
[["You", "are", "pretty", "."], "1"]
|
||||
]
|
||||
training_set = p.run(train_dev_data)
|
||||
print(training_set)
|
||||
|
@ -1,9 +1,8 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fastNLP.core.action import Action
|
||||
from fastNLP.core.action import RandomSampler, Batchifier
|
||||
from fastNLP.modules import utils
|
||||
from fastNLP.core.action import RandomSampler
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.saver.logger import create_logger
|
||||
|
||||
logger = create_logger(__name__, "./train_test.log")
|
||||
@ -35,16 +34,16 @@ class BaseTester(object):
|
||||
"""
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
Obviously, "required_args" is the subset of "default_args".
|
||||
The value in "default_args" to the keys in "required_args" is simply for type check.
|
||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
|
||||
"""
|
||||
# add required arguments here
|
||||
required_args = {}
|
||||
required_args = {"task" # one of ("seq_label", "text_classify")
|
||||
}
|
||||
|
||||
for req_key in required_args:
|
||||
if req_key not in kwargs:
|
||||
logger.error("Tester lacks argument {}".format(req_key))
|
||||
raise ValueError("Tester lacks argument {}".format(req_key))
|
||||
self._task = kwargs["task"]
|
||||
|
||||
for key in default_args:
|
||||
if key in kwargs:
|
||||
@ -79,14 +78,14 @@ class BaseTester(object):
|
||||
self._model = network
|
||||
|
||||
# turn on the testing mode; clean up the history
|
||||
self.mode(network, test=True)
|
||||
self.mode(network, is_test=True)
|
||||
self.eval_history.clear()
|
||||
self.batch_output.clear()
|
||||
|
||||
iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False))
|
||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
|
||||
step = 0
|
||||
|
||||
for batch_x, batch_y in self.make_batch(iterator):
|
||||
for batch_x, batch_y in data_iterator:
|
||||
with torch.no_grad():
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
eval_results = self.evaluate(prediction, batch_y)
|
||||
@ -102,17 +101,22 @@ class BaseTester(object):
|
||||
print(self.make_eval_output(prediction, eval_results))
|
||||
step += 1
|
||||
|
||||
def mode(self, model, test):
|
||||
def mode(self, model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
|
||||
:param model: a PyTorch model
|
||||
:param test: bool, whether in test mode.
|
||||
:param is_test: bool, whether in test mode or not.
|
||||
|
||||
"""
|
||||
Action.mode(model, test)
|
||||
if is_test:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""A forward pass of the model. """
|
||||
raise NotImplementedError
|
||||
y = network(**x)
|
||||
return y
|
||||
|
||||
def evaluate(self, predict, truth):
|
||||
"""Compute evaluation metrics.
|
||||
@ -121,7 +125,38 @@ class BaseTester(object):
|
||||
:param truth: Tensor
|
||||
:return eval_results: can be anything. It will be stored in self.eval_history
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if "label_seq" in truth:
|
||||
truth = truth["label_seq"]
|
||||
elif "label" in truth:
|
||||
truth = truth["label"]
|
||||
else:
|
||||
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
|
||||
|
||||
if self._task == "seq_label":
|
||||
return self._seq_label_evaluate(predict, truth)
|
||||
elif self._task == "text_classify":
|
||||
return self._text_classify_evaluate(predict, truth)
|
||||
else:
|
||||
raise NotImplementedError("Unknown task type {}.".format(self._task))
|
||||
|
||||
def _seq_label_evaluate(self, predict, truth):
|
||||
batch_size, max_len = predict.size(0), predict.size(1)
|
||||
loss = self._model.loss(predict, truth) / batch_size
|
||||
prediction = self._model.prediction(predict)
|
||||
# pad prediction to equal length
|
||||
for pred in prediction:
|
||||
if len(pred) < max_len:
|
||||
pred += [0] * (max_len - len(pred))
|
||||
results = torch.Tensor(prediction).view(-1, )
|
||||
|
||||
# make sure "results" is in the same device as "truth"
|
||||
results = results.to(truth)
|
||||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
|
||||
return [float(loss), float(accuracy)]
|
||||
|
||||
def _text_classify_evaluate(self, y_logit, y_true):
|
||||
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
|
||||
return [y_prob, y_true]
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
@ -131,7 +166,27 @@ class BaseTester(object):
|
||||
|
||||
:return : variable number of outputs
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if self._task == "seq_label":
|
||||
return self._seq_label_metrics
|
||||
elif self._task == "text_classify":
|
||||
return self._text_classify_metrics
|
||||
else:
|
||||
raise NotImplementedError("Unknown task type {}.".format(self._task))
|
||||
|
||||
@property
|
||||
def _seq_label_metrics(self):
|
||||
batch_loss = np.mean([x[0] for x in self.eval_history])
|
||||
batch_accuracy = np.mean([x[1] for x in self.eval_history])
|
||||
return batch_loss, batch_accuracy
|
||||
|
||||
@property
|
||||
def _text_classify_metrics(self):
|
||||
y_prob, y_true = zip(*self.eval_history)
|
||||
y_prob = torch.cat(y_prob, dim=0)
|
||||
y_pred = torch.argmax(y_prob, dim=-1)
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
|
||||
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
|
||||
|
||||
def show_metrics(self):
|
||||
"""Customize evaluation outputs in Trainer.
|
||||
@ -140,10 +195,8 @@ class BaseTester(object):
|
||||
|
||||
:return print_str: str
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def make_batch(self, iterator):
|
||||
raise NotImplementedError
|
||||
loss, accuracy = self.metrics
|
||||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
|
||||
|
||||
def make_eval_output(self, predictions, eval_results):
|
||||
"""Customize Tester outputs.
|
||||
@ -152,108 +205,20 @@ class BaseTester(object):
|
||||
:param eval_results: Tensor
|
||||
:return: str, to be printed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return self.show_metrics()
|
||||
|
||||
|
||||
class SeqLabelTester(BaseTester):
|
||||
"""Tester for sequence labeling.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **test_args):
|
||||
"""
|
||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
|
||||
"""
|
||||
test_args.update({"task": "seq_label"})
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
|
||||
super(SeqLabelTester, self).__init__(**test_args)
|
||||
self.max_len = None
|
||||
self.mask = None
|
||||
self.seq_len = None
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
"""This is only for sequence labeling with CRF decoder.
|
||||
|
||||
:param network: a PyTorch model
|
||||
:param inputs: tuple of (x, seq_len)
|
||||
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
|
||||
after padding.
|
||||
seq_len: list of int, the lengths of sequences before padding.
|
||||
:return y: Tensor of shape [batch_size, max_len]
|
||||
"""
|
||||
if not isinstance(inputs, tuple):
|
||||
raise RuntimeError("output_length must be true for sequence modeling.")
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
mask = mask.cuda()
|
||||
self.mask = mask
|
||||
self.seq_len = seq_len
|
||||
y = network(x)
|
||||
return y
|
||||
|
||||
def evaluate(self, predict, truth):
|
||||
"""Compute metrics (or loss).
|
||||
|
||||
:param predict: Tensor, [batch_size, max_len, tag_size]
|
||||
:param truth: Tensor, [batch_size, max_len]
|
||||
:return:
|
||||
"""
|
||||
batch_size, max_len = predict.size(0), predict.size(1)
|
||||
loss = self._model.loss(predict, truth, self.mask) / batch_size
|
||||
|
||||
prediction = self._model.prediction(predict, self.mask)
|
||||
results = torch.Tensor(prediction).view(-1, )
|
||||
# make sure "results" is in the same device as "truth"
|
||||
results = results.to(truth)
|
||||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
|
||||
return [float(loss), float(accuracy)]
|
||||
|
||||
def metrics(self):
|
||||
batch_loss = np.mean([x[0] for x in self.eval_history])
|
||||
batch_accuracy = np.mean([x[1] for x in self.eval_history])
|
||||
return batch_loss, batch_accuracy
|
||||
|
||||
def show_metrics(self):
|
||||
"""This is called by Trainer to print evaluation on dev set.
|
||||
|
||||
:return print_str: str
|
||||
"""
|
||||
loss, accuracy = self.metrics()
|
||||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
|
||||
|
||||
def make_batch(self, iterator):
|
||||
return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)
|
||||
|
||||
|
||||
class ClassificationTester(BaseTester):
|
||||
"""Tester for classification."""
|
||||
|
||||
def __init__(self, **test_args):
|
||||
"""
|
||||
:param test_args: a dict-like object that has __getitem__ method.
|
||||
can be accessed by "test_args["key_str"]"
|
||||
"""
|
||||
test_args.update({"task": "seq_label"})
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
|
||||
super(ClassificationTester, self).__init__(**test_args)
|
||||
|
||||
def make_batch(self, iterator, max_len=None):
|
||||
return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""Forward through network."""
|
||||
logits = network(x)
|
||||
return logits
|
||||
|
||||
def evaluate(self, y_logit, y_true):
|
||||
"""Return y_pred and y_true."""
|
||||
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
|
||||
return [y_prob, y_true]
|
||||
|
||||
def metrics(self):
|
||||
"""Compute accuracy."""
|
||||
y_prob, y_true = zip(*self.eval_history)
|
||||
y_prob = torch.cat(y_prob, dim=0)
|
||||
y_pred = torch.argmax(y_prob, dim=-1)
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
|
||||
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
|
||||
|
@ -4,15 +4,13 @@ import time
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import tensorboardX
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from fastNLP.core.action import Action
|
||||
from fastNLP.core.action import RandomSampler, Batchifier
|
||||
from fastNLP.core.action import RandomSampler
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.loss import Loss
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
|
||||
from fastNLP.modules import utils
|
||||
from fastNLP.saver.logger import create_logger
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
@ -50,16 +48,16 @@ class BaseTrainer(object):
|
||||
"""
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
Obviously, "required_args" is the subset of "default_args".
|
||||
The value in "default_args" to the keys in "required_args" is simply for type check.
|
||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
|
||||
"""
|
||||
# add required arguments here
|
||||
required_args = {}
|
||||
required_args = {"task" # one of ("seq_label", "text_classify")
|
||||
}
|
||||
|
||||
for req_key in required_args:
|
||||
if req_key not in kwargs:
|
||||
logger.error("Trainer lacks argument {}".format(req_key))
|
||||
raise ValueError("Trainer lacks argument {}".format(req_key))
|
||||
self._task = kwargs["task"]
|
||||
|
||||
for key in default_args:
|
||||
if key in kwargs:
|
||||
@ -90,13 +88,14 @@ class BaseTrainer(object):
|
||||
self._optimizer_proto = default_args["optimizer"]
|
||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
|
||||
self._graph_summaried = False
|
||||
self._best_accuracy = 0.0
|
||||
|
||||
def train(self, network, train_data, dev_data=None):
|
||||
"""General Training Procedure
|
||||
|
||||
:param network: a model
|
||||
:param train_data: three-level list, the training set.
|
||||
:param dev_data: three-level list, the validation data (optional)
|
||||
:param train_data: a DataSet instance, the training data
|
||||
:param dev_data: a DataSet instance, the validation data (optional)
|
||||
"""
|
||||
# transfer model to gpu if available
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
@ -126,9 +125,10 @@ class BaseTrainer(object):
|
||||
logger.info("training epoch {}".format(epoch))
|
||||
|
||||
# turn on network training mode
|
||||
self.mode(network, test=False)
|
||||
self.mode(network, is_test=False)
|
||||
# prepare mini-batch iterator
|
||||
data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
|
||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
|
||||
use_cuda=self.use_cuda)
|
||||
logger.info("prepared data iterator")
|
||||
|
||||
# one forward and backward pass
|
||||
@ -157,7 +157,7 @@ class BaseTrainer(object):
|
||||
- epoch: int,
|
||||
"""
|
||||
step = 0
|
||||
for batch_x, batch_y in self.make_batch(data_iterator):
|
||||
for batch_x, batch_y in data_iterator:
|
||||
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
@ -166,10 +166,6 @@ class BaseTrainer(object):
|
||||
self.update()
|
||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
|
||||
|
||||
if not self._graph_summaried:
|
||||
self._summary_writer.add_graph(network, batch_x)
|
||||
self._graph_summaried = True
|
||||
|
||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
|
||||
end = time.time()
|
||||
diff = timedelta(seconds=round(end - kwargs["start"]))
|
||||
@ -204,11 +200,17 @@ class BaseTrainer(object):
|
||||
network_copy = copy.deepcopy(network)
|
||||
self.train(network_copy, train_data_cv[i], dev_data_cv[i])
|
||||
|
||||
def make_batch(self, iterator):
|
||||
raise NotImplementedError
|
||||
def mode(self, model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
|
||||
def mode(self, network, test):
|
||||
Action.mode(network, test)
|
||||
:param model: a PyTorch model
|
||||
:param is_test: bool, whether in test mode or not.
|
||||
|
||||
"""
|
||||
if is_test:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
def define_optimizer(self):
|
||||
"""Define framework-specific optimizer specified by the models.
|
||||
@ -224,7 +226,20 @@ class BaseTrainer(object):
|
||||
self._optimizer.step()
|
||||
|
||||
def data_forward(self, network, x):
|
||||
raise NotImplementedError
|
||||
if self._task == "seq_label":
|
||||
y = network(x["word_seq"], x["word_seq_origin_len"])
|
||||
elif self._task == "text_classify":
|
||||
y = network(x["word_seq"])
|
||||
else:
|
||||
raise NotImplementedError("Unknown task type {}.".format(self._task))
|
||||
|
||||
if not self._graph_summaried:
|
||||
if self._task == "seq_label":
|
||||
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
|
||||
elif self._task == "text_classify":
|
||||
self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
|
||||
self._graph_summaried = True
|
||||
return y
|
||||
|
||||
def grad_backward(self, loss):
|
||||
"""Compute gradient with link rules.
|
||||
@ -243,6 +258,13 @@ class BaseTrainer(object):
|
||||
:param truth: ground truth label vector
|
||||
:return: a scalar
|
||||
"""
|
||||
if "label_seq" in truth:
|
||||
truth = truth["label_seq"]
|
||||
elif "label" in truth:
|
||||
truth = truth["label"]
|
||||
truth = truth.view((-1,))
|
||||
else:
|
||||
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
|
||||
return self._loss_func(predict, truth)
|
||||
|
||||
def define_loss(self):
|
||||
@ -270,7 +292,12 @@ class BaseTrainer(object):
|
||||
:param validator: a Tester instance
|
||||
:return: bool, True means current results on dev set is the best.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
loss, accuracy = validator.metrics
|
||||
if accuracy > self._best_accuracy:
|
||||
self._best_accuracy = accuracy
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def save_model(self, network, model_name):
|
||||
"""Save this model with such a name.
|
||||
@ -291,55 +318,11 @@ class SeqLabelTrainer(BaseTrainer):
|
||||
"""Trainer for Sequence Labeling
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.update({"task": "seq_label"})
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
|
||||
super(SeqLabelTrainer, self).__init__(**kwargs)
|
||||
# self.vocab_size = kwargs["vocab_size"]
|
||||
# self.num_classes = kwargs["num_classes"]
|
||||
self.max_len = None
|
||||
self.mask = None
|
||||
self.best_accuracy = 0.0
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
if not isinstance(inputs, tuple):
|
||||
raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0])))
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
mask = mask.cuda()
|
||||
self.mask = mask
|
||||
|
||||
y = network(x)
|
||||
return y
|
||||
|
||||
def get_loss(self, predict, truth):
|
||||
"""Compute loss given prediction and ground truth.
|
||||
|
||||
:param predict: prediction label vector, [batch_size, max_len, tag_size]
|
||||
:param truth: ground truth label vector, [batch_size, max_len]
|
||||
:return loss: a scalar
|
||||
"""
|
||||
batch_size, max_len = predict.size(0), predict.size(1)
|
||||
assert truth.shape == (batch_size, max_len)
|
||||
|
||||
loss = self._model.loss(predict, truth, self.mask)
|
||||
return loss
|
||||
|
||||
def best_eval_result(self, validator):
|
||||
loss, accuracy = validator.metrics()
|
||||
if accuracy > self.best_accuracy:
|
||||
self.best_accuracy = accuracy
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def make_batch(self, iterator):
|
||||
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda)
|
||||
|
||||
def _create_validator(self, valid_args):
|
||||
return SeqLabelTester(**valid_args)
|
||||
@ -349,33 +332,10 @@ class ClassificationTrainer(BaseTrainer):
|
||||
"""Trainer for text classification."""
|
||||
|
||||
def __init__(self, **train_args):
|
||||
train_args.update({"task": "text_classify"})
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
|
||||
super(ClassificationTrainer, self).__init__(**train_args)
|
||||
|
||||
self.iterator = None
|
||||
self.loss_func = None
|
||||
self.optimizer = None
|
||||
self.best_accuracy = 0
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""Forward through network."""
|
||||
logits = network(x)
|
||||
return logits
|
||||
|
||||
def make_batch(self, iterator):
|
||||
return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)
|
||||
|
||||
def get_acc(self, y_logit, y_true):
|
||||
"""Compute accuracy."""
|
||||
y_pred = torch.argmax(y_logit, dim=-1)
|
||||
return int(torch.sum(y_true == y_pred)) / len(y_true)
|
||||
|
||||
def best_eval_result(self, validator):
|
||||
_, _, accuracy = validator.metrics()
|
||||
if accuracy > self.best_accuracy:
|
||||
self.best_accuracy = accuracy
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _create_validator(self, valid_args):
|
||||
return ClassificationTester(**valid_args)
|
||||
|
@ -35,8 +35,12 @@ class CNNText(torch.nn.Module):
|
||||
self.dropout = nn.Dropout(drop_prob)
|
||||
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embed(x) # [N,L] -> [N,L,C]
|
||||
def forward(self, word_seq):
|
||||
"""
|
||||
:param word_seq: torch.LongTensor, [batch_size, seq_len]
|
||||
:return x: torch.LongTensor, [batch_size, num_classes]
|
||||
"""
|
||||
x = self.embed(word_seq) # [N,L] -> [N,L,C]
|
||||
x = self.conv_pool(x) # [N,L,C] -> [N,C]
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x) # [N,C] -> [N, N_class]
|
||||
|
@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import decoder, encoder
|
||||
|
||||
|
||||
def seq_mask(seq_len, max_len):
|
||||
"""Create a mask for the sequences.
|
||||
|
||||
:param seq_len: list or torch.LongTensor
|
||||
:param max_len: int
|
||||
:return mask: torch.LongTensor
|
||||
"""
|
||||
if isinstance(seq_len, list):
|
||||
seq_len = torch.LongTensor(seq_len)
|
||||
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)]
|
||||
mask = torch.stack(mask, 1)
|
||||
return mask
|
||||
|
||||
|
||||
class SeqLabeling(BaseModel):
|
||||
"""
|
||||
PyTorch Network for sequence labeling
|
||||
@ -20,13 +34,17 @@ class SeqLabeling(BaseModel):
|
||||
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
|
||||
self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
|
||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
|
||||
self.mask = None
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, word_seq, word_seq_origin_len):
|
||||
"""
|
||||
:param x: LongTensor, [batch_size, mex_len]
|
||||
:param word_seq: LongTensor, [batch_size, mex_len]
|
||||
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
|
||||
:return y: [batch_size, mex_len, tag_size]
|
||||
"""
|
||||
x = self.Embedding(x)
|
||||
self.mask = self.make_mask(word_seq, word_seq_origin_len)
|
||||
|
||||
x = self.Embedding(word_seq)
|
||||
# [batch_size, max_len, word_emb_dim]
|
||||
x = self.Rnn(x)
|
||||
# [batch_size, max_len, hidden_size * direction]
|
||||
@ -34,27 +52,34 @@ class SeqLabeling(BaseModel):
|
||||
# [batch_size, max_len, num_classes]
|
||||
return x
|
||||
|
||||
def loss(self, x, y, mask):
|
||||
def loss(self, x, y):
|
||||
"""
|
||||
Negative log likelihood loss.
|
||||
:param x: Tensor, [batch_size, max_len, tag_size]
|
||||
:param y: Tensor, [batch_size, max_len]
|
||||
:param mask: ByteTensor, [batch_size, ,max_len]
|
||||
:return loss: a scalar Tensor
|
||||
|
||||
"""
|
||||
x = x.float()
|
||||
y = y.long()
|
||||
total_loss = self.Crf(x, y, mask)
|
||||
assert x.shape[:2] == y.shape
|
||||
assert y.shape == self.mask.shape
|
||||
total_loss = self.Crf(x, y, self.mask)
|
||||
return torch.mean(total_loss)
|
||||
|
||||
def prediction(self, x, mask):
|
||||
def make_mask(self, x, seq_len):
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
mask = mask.to(x)
|
||||
return mask
|
||||
|
||||
def prediction(self, x):
|
||||
"""
|
||||
:param x: FloatTensor, [batch_size, max_len, tag_size]
|
||||
:param mask: ByteTensor, [batch_size, max_len]
|
||||
:return prediction: list of [decode path(list)]
|
||||
"""
|
||||
tag_seq = self.Crf.viterbi_decode(x, mask)
|
||||
tag_seq = self.Crf.viterbi_decode(x, self.mask)
|
||||
return tag_seq
|
||||
|
||||
|
||||
@ -81,14 +106,17 @@ class AdvSeqLabel(SeqLabeling):
|
||||
|
||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, word_seq, word_seq_origin_len):
|
||||
"""
|
||||
:param x: LongTensor, [batch_size, mex_len]
|
||||
:param word_seq: LongTensor, [batch_size, mex_len]
|
||||
:param word_seq_origin_len: list of int.
|
||||
:return y: [batch_size, mex_len, tag_size]
|
||||
"""
|
||||
batch_size = x.size(0)
|
||||
max_len = x.size(1)
|
||||
x = self.Embedding(x)
|
||||
self.mask = self.make_mask(word_seq, word_seq_origin_len)
|
||||
|
||||
batch_size = word_seq.size(0)
|
||||
max_len = word_seq.size(1)
|
||||
x = self.Embedding(word_seq)
|
||||
# [batch_size, max_len, word_emb_dim]
|
||||
x = self.Rnn(x)
|
||||
# [batch_size, max_len, hidden_size * direction]
|
||||
|
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Self Attention Module.
|
||||
@ -13,13 +15,18 @@ class SelfAttention(nn.Module):
|
||||
num_vec: int, the number of encoded vectors
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, dim=10, num_vec=10):
|
||||
def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
|
||||
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
|
||||
# self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
|
||||
# self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
|
||||
self.attention_hops = num_vec
|
||||
|
||||
self.ws1 = nn.Linear(input_size, dim, bias=False)
|
||||
self.ws2 = nn.Linear(dim, num_vec, bias=False)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def penalization(self, A):
|
||||
"""
|
||||
compute the penalization term for attention module
|
||||
@ -32,11 +39,33 @@ class SelfAttention(nn.Module):
|
||||
M = M.view(M.size(0), -1)
|
||||
return torch.sum(M ** 2, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))
|
||||
A = self.softmax(torch.matmul(self.W_s2, inter))
|
||||
out = torch.matmul(A, x)
|
||||
out = out.view(out.size(0), -1)
|
||||
penalty = self.penalization(A)
|
||||
return out, penalty
|
||||
def forward(self, outp ,inp):
|
||||
# the following code can not be use because some word are padding ,these is not such module!
|
||||
|
||||
# inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # []
|
||||
# A = self.softmax(torch.matmul(self.W_s2, inter))
|
||||
# out = torch.matmul(A, x)
|
||||
# out = out.view(out.size(0), -1)
|
||||
# penalty = self.penalization(A)
|
||||
# return out, penalty
|
||||
outp = outp.contiguous()
|
||||
size = outp.size() # [bsz, len, nhid]
|
||||
|
||||
compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2]
|
||||
transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len]
|
||||
transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len]
|
||||
concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
|
||||
concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len]
|
||||
|
||||
hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit]
|
||||
attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop]
|
||||
attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len]
|
||||
penalized_alphas = attention + (
|
||||
-10000 * (concatenated_inp == 0).float())
|
||||
# [bsz, hop, len] + [bsz, hop, len]
|
||||
attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len]
|
||||
attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len]
|
||||
return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid]
|
||||
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
def log_sum_exp(x, dim=-1):
|
||||
max_value, _ = x.max(dim=dim, keepdim=True)
|
||||
@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens):
|
||||
|
||||
|
||||
class ConditionalRandomField(nn.Module):
|
||||
def __init__(self, tag_size, include_start_end_trans=True):
|
||||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
|
||||
"""
|
||||
:param tag_size: int, num of tags
|
||||
:param include_start_end_trans: bool, whether to include start/end tag
|
||||
@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module):
|
||||
self.start_scores = nn.Parameter(torch.randn(tag_size))
|
||||
self.end_scores = nn.Parameter(torch.randn(tag_size))
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
# self.reset_parameter()
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameter(self):
|
||||
nn.init.xavier_normal_(self.transition_m)
|
||||
if self.include_start_end_trans:
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, size_layer, num_class=2, activation='relu'):
|
||||
def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None):
|
||||
"""Multilayer Perceptrons as a decoder
|
||||
|
||||
Args:
|
||||
@ -36,7 +36,7 @@ class MLP(nn.Module):
|
||||
self.hidden_active = activation
|
||||
else:
|
||||
raise ValueError("should set activation correctly: {}".format(activation))
|
||||
|
||||
initial_parameter(self, initial_method )
|
||||
def forward(self, x):
|
||||
for layer in self.hiddens:
|
||||
x = self.hidden_active(layer(x))
|
||||
|
@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
# from torch.nn.init import xavier_uniform
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class ConvCharEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
|
||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None):
|
||||
"""
|
||||
Character Level Word Embedding
|
||||
:param char_emb_size: the size of character level embedding. Default: 50
|
||||
@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module):
|
||||
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
|
||||
for i in range(len(kernels))])
|
||||
|
||||
initial_parameter(self,initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: [batch_size * sent_length, word_length, char_emb_size]
|
||||
@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module):
|
||||
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
|
||||
"""
|
||||
|
||||
def __init__(self, char_emb_size=50, hidden_size=None):
|
||||
def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None):
|
||||
super(LSTMCharEmbedding, self).__init__()
|
||||
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
|
||||
|
||||
@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module):
|
||||
num_layers=1,
|
||||
bias=True,
|
||||
batch_first=True)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x:[ n_batch*n_word, word_length, char_emb_size]
|
||||
|
@ -6,6 +6,7 @@ import torch.nn as nn
|
||||
from torch.nn.init import xavier_uniform_
|
||||
# import torch.nn.functional as F
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
class Conv(nn.Module):
|
||||
"""
|
||||
@ -15,7 +16,7 @@ class Conv(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size,
|
||||
stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, activation='relu'):
|
||||
groups=1, bias=True, activation='relu',initial_method = None ):
|
||||
super(Conv, self).__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
@ -26,7 +27,7 @@ class Conv(nn.Module):
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
xavier_uniform_(self.conv.weight)
|
||||
# xavier_uniform_(self.conv.weight)
|
||||
|
||||
activations = {
|
||||
'relu': nn.ReLU(),
|
||||
@ -37,6 +38,7 @@ class Conv(nn.Module):
|
||||
raise Exception(
|
||||
'Should choose activation function from: ' +
|
||||
', '.join([x for x in activations]))
|
||||
initial_parameter(self, initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import xavier_uniform_
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
class ConvMaxpool(nn.Module):
|
||||
"""
|
||||
@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_sizes,
|
||||
stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, activation='relu'):
|
||||
groups=1, bias=True, activation='relu',initial_method = None ):
|
||||
super(ConvMaxpool, self).__init__()
|
||||
|
||||
# convolution
|
||||
@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module):
|
||||
raise Exception(
|
||||
"Undefined activation function: choose from: relu")
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
# [N,L,C] -> [N,C,L]
|
||||
x = torch.transpose(x, 1, 2)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class Linear(nn.Module):
|
||||
"""
|
||||
Linear module
|
||||
@ -12,10 +12,10 @@ class Linear(nn.Module):
|
||||
bidirectional : If True, becomes a bidirectional RNN
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, bias=True):
|
||||
def __init__(self, input_size, output_size, bias=True,initial_method = None ):
|
||||
super(Linear, self).__init__()
|
||||
self.linear = nn.Linear(input_size, output_size, bias)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class Lstm(nn.Module):
|
||||
"""
|
||||
LSTM module
|
||||
@ -13,11 +13,13 @@ class Lstm(nn.Module):
|
||||
bidirectional : If True, becomes a bidirectional RNN. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False):
|
||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
|
||||
super(Lstm, self).__init__()
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
|
||||
dropout=dropout, bidirectional=bidirectional)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
if __name__ == "__main__":
|
||||
lstm = Lstm(10)
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
def MaskedRecurrent(reverse=False):
|
||||
def forward(input, hidden, cell, mask, train=True, dropout=0):
|
||||
"""
|
||||
@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False):
|
||||
class MaskedRNNBase(nn.Module):
|
||||
def __init__(self, Cell, input_size, hidden_size,
|
||||
num_layers=1, bias=True, batch_first=False,
|
||||
layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs):
|
||||
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs):
|
||||
"""
|
||||
:param Cell:
|
||||
:param input_size:
|
||||
@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module):
|
||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs)
|
||||
self.all_cells.append(cell)
|
||||
self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for cell in self.all_cells:
|
||||
cell.reset_parameters()
|
||||
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
def default_initializer(hidden_size):
|
||||
stdv = 1.0 / math.sqrt(hidden_size)
|
||||
@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False):
|
||||
class VarMaskedRNNBase(nn.Module):
|
||||
def __init__(self, Cell, input_size, hidden_size,
|
||||
num_layers=1, bias=True, batch_first=False,
|
||||
dropout=(0, 0), bidirectional=False, initializer=None, **kwargs):
|
||||
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs):
|
||||
|
||||
super(VarMaskedRNNBase, self).__init__()
|
||||
self.Cell = Cell
|
||||
@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module):
|
||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs)
|
||||
self.all_cells.append(cell)
|
||||
self.add_module('cell%d' % (layer * num_directions + direction), cell)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for cell in self.all_cells:
|
||||
cell.reset_parameters()
|
||||
@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase):
|
||||
\end{array}
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None):
|
||||
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None):
|
||||
super(VarFastLSTMCell, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase):
|
||||
self.p_hidden = p_hidden
|
||||
self.noise_in = None
|
||||
self.noise_hidden = None
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for weight in self.parameters():
|
||||
if weight.dim() == 1:
|
||||
|
@ -2,8 +2,8 @@ from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
import torch.nn.init as init
|
||||
import torch.nn as nn
|
||||
def mask_softmax(matrix, mask):
|
||||
if mask is None:
|
||||
result = torch.nn.functional.softmax(matrix, dim=-1)
|
||||
@ -11,6 +11,51 @@ def mask_softmax(matrix, mask):
|
||||
raise NotImplementedError
|
||||
return result
|
||||
|
||||
def initial_parameter(net ,initial_method =None):
|
||||
|
||||
if initial_method == 'xavier_uniform':
|
||||
init_method = init.xavier_uniform_
|
||||
elif initial_method=='xavier_normal':
|
||||
init_method = init.xavier_normal_
|
||||
elif initial_method == 'kaiming_normal' or initial_method =='msra':
|
||||
init_method = init.kaiming_normal
|
||||
elif initial_method == 'kaiming_uniform':
|
||||
init_method = init.kaiming_normal
|
||||
elif initial_method == 'orthogonal':
|
||||
init_method = init.orthogonal_
|
||||
elif initial_method == 'sparse':
|
||||
init_method = init.sparse_
|
||||
elif initial_method =='normal':
|
||||
init_method = init.normal_
|
||||
elif initial_method =='uniform':
|
||||
initial_method = init.uniform_
|
||||
else:
|
||||
init_method = init.xavier_normal_
|
||||
def weights_init(m):
|
||||
# classname = m.__class__.__name__
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn
|
||||
if initial_method != None:
|
||||
init_method(m.weight.data)
|
||||
else:
|
||||
init.xavier_normal_(m.weight.data)
|
||||
init.normal_(m.bias.data)
|
||||
elif isinstance(m, nn.LSTM):
|
||||
for w in m.parameters():
|
||||
if len(w.data.size())>1:
|
||||
init_method(w.data) # weight
|
||||
else:
|
||||
init.normal_(w.data) # bias
|
||||
elif hasattr(m, 'weight') and m.weight.requires_grad:
|
||||
init_method(m.weight.data)
|
||||
else:
|
||||
for w in m.parameters() :
|
||||
if w.requires_grad:
|
||||
if len(w.data.size())>1:
|
||||
init_method(w.data) # weight
|
||||
else:
|
||||
init.normal_(w.data) # bias
|
||||
# print("init else")
|
||||
net.apply(weights_init)
|
||||
|
||||
def seq_mask(seq_len, max_len):
|
||||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
|
||||
|
@ -0,0 +1,13 @@
|
||||
[train]
|
||||
epochs = 30
|
||||
batch_size = 32
|
||||
pickle_path = "./save/"
|
||||
validate = true
|
||||
save_best_dev = true
|
||||
model_saved_path = "./save/"
|
||||
rnn_hidden_units = 300
|
||||
word_emb_dim = 300
|
||||
use_crf = true
|
||||
use_cuda = false
|
||||
loss_func = "cross_entropy"
|
||||
num_classes = 5
|
80
reproduction/LSTM+self_attention_sentiment_analysis/main.py
Normal file
80
reproduction/LSTM+self_attention_sentiment_analysis/main.py
Normal file
@ -0,0 +1,80 @@
|
||||
|
||||
import os
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
|
||||
from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader
|
||||
from fastNLP.loader.config_loader import ConfigSection
|
||||
from fastNLP.loader.config_loader import ConfigLoader
|
||||
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
|
||||
from fastNLP.core.preprocess import ClassPreprocess as Preprocess
|
||||
from fastNLP.core.trainer import ClassificationTrainer
|
||||
|
||||
from fastNLP.modules.encoder.embedding import Embedding as Embedding
|
||||
from fastNLP.modules.encoder.lstm import Lstm
|
||||
from fastNLP.modules.aggregation.self_attention import SelfAttention
|
||||
from fastNLP.modules.decoder.MLP import MLP
|
||||
|
||||
|
||||
train_data_path = 'small_train_data.txt'
|
||||
dev_data_path = 'small_dev_data.txt'
|
||||
# emb_path = 'glove.txt'
|
||||
|
||||
lstm_hidden_size = 300
|
||||
embeding_size = 300
|
||||
attention_unit = 350
|
||||
attention_hops = 10
|
||||
class_num = 5
|
||||
nfc = 3000
|
||||
### data load ###
|
||||
train_dataset = Dataset_loader(train_data_path)
|
||||
train_data = train_dataset.load()
|
||||
|
||||
dev_args = Dataset_loader(dev_data_path)
|
||||
dev_data = dev_args.load()
|
||||
|
||||
###### preprocess ####
|
||||
preprocess = Preprocess()
|
||||
word2index, label2index = preprocess.build_dict(train_data)
|
||||
train_data, dev_data = preprocess.run(train_data, dev_data)
|
||||
|
||||
|
||||
|
||||
# emb = EmbedLoader(emb_path)
|
||||
# embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index)
|
||||
### construct vocab ###
|
||||
|
||||
class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
|
||||
def __init__(self, args=None):
|
||||
super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
|
||||
self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
|
||||
self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
|
||||
self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
|
||||
self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,)
|
||||
def forward(self,x):
|
||||
x_emb = self.embedding(x)
|
||||
output = self.lstm(x_emb)
|
||||
after_attention, penalty = self.attention(output,x)
|
||||
after_attention =after_attention.view(after_attention.size(0),-1)
|
||||
output = self.mlp(after_attention)
|
||||
return output
|
||||
|
||||
def loss(self, predict, ground_truth):
|
||||
print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size())))
|
||||
print(ground_truth)
|
||||
return F.cross_entropy(predict, ground_truth)
|
||||
|
||||
train_args = ConfigSection()
|
||||
ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
|
||||
train_args['vocab'] = len(word2index)
|
||||
|
||||
|
||||
trainer = ClassificationTrainer(**train_args.data)
|
||||
|
||||
# for k in train_args.__dict__.keys():
|
||||
# print(k, train_args[k])
|
||||
model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
|
||||
trainer.train(model,train_data , dev_data)
|
8
setup.py
8
setup.py
@ -2,18 +2,18 @@
|
||||
# coding=utf-8
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
with open('README.md', encoding='utf-8') as f:
|
||||
readme = f.read()
|
||||
|
||||
with open('LICENSE') as f:
|
||||
with open('LICENSE', encoding='utf-8') as f:
|
||||
license = f.read()
|
||||
|
||||
with open('requirements.txt') as f:
|
||||
with open('requirements.txt', encoding='utf-8') as f:
|
||||
reqs = f.read()
|
||||
|
||||
setup(
|
||||
name='fastNLP',
|
||||
version='0.0.1',
|
||||
version='0.0.3',
|
||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
|
||||
long_description=readme,
|
||||
license=license,
|
||||
|
@ -1,17 +0,0 @@
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.action import Action, Batchifier, SequentialSampler
|
||||
|
||||
|
||||
class TestAction(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
y = [1, 1, 1, 1, 2, 2, 2, 2]
|
||||
data = []
|
||||
for i in range(len(x)):
|
||||
data.append([[x[i]], [y[i]]])
|
||||
data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False)
|
||||
action = Action()
|
||||
for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None):
|
||||
print(batch_x)
|
||||
|
62
test/core/test_batch.py
Normal file
62
test/core/test_batch.py
Normal file
@ -0,0 +1,62 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
|
||||
from fastNLP.core.field import TextField, LabelField
|
||||
from fastNLP.core.instance import Instance
|
||||
|
||||
raw_texts = ["i am a cat",
|
||||
"this is a test of new batch",
|
||||
"ha ha",
|
||||
"I am a good boy .",
|
||||
"This is the most beautiful girl ."
|
||||
]
|
||||
texts = [text.strip().split() for text in raw_texts]
|
||||
labels = [0, 1, 0, 0, 1]
|
||||
|
||||
# prepare vocabulary
|
||||
vocab = {}
|
||||
for text in texts:
|
||||
for tokens in text:
|
||||
if tokens not in vocab:
|
||||
vocab[tokens] = len(vocab)
|
||||
|
||||
|
||||
class TestCase1(unittest.TestCase):
|
||||
def test(self):
|
||||
data = DataSet()
|
||||
for text, label in zip(texts, labels):
|
||||
x = TextField(text, is_target=False)
|
||||
y = LabelField(label, is_target=True)
|
||||
ins = Instance(text=x, label=y)
|
||||
data.append(ins)
|
||||
|
||||
# use vocabulary to index data
|
||||
data.index_field("text", vocab)
|
||||
|
||||
# define naive sampler for batch class
|
||||
class SeqSampler:
|
||||
def __call__(self, dataset):
|
||||
return list(range(len(dataset)))
|
||||
|
||||
# use batch to iterate dataset
|
||||
data_iterator = Batch(data, 2, SeqSampler(), False)
|
||||
for batch_x, batch_y in data_iterator:
|
||||
self.assertEqual(len(batch_x), 2)
|
||||
self.assertTrue(isinstance(batch_x, dict))
|
||||
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
|
||||
self.assertTrue(isinstance(batch_y, dict))
|
||||
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))
|
||||
|
||||
|
||||
class TestCase2(unittest.TestCase):
|
||||
def test(self):
|
||||
data = DataSet()
|
||||
for text in texts:
|
||||
x = TextField(text, is_target=False)
|
||||
ins = Instance(text=x)
|
||||
data.append(ins)
|
||||
data_set = create_dataset_from_lists(texts, vocab, has_target=False)
|
||||
self.assertTrue(type(data) == type(data_set))
|
51
test/core/test_predictor.py
Normal file
51
test/core/test_predictor.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.predictor import Predictor
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
|
||||
|
||||
class TestPredictor(unittest.TestCase):
|
||||
def test_seq_label(self):
|
||||
model_args = {
|
||||
"vocab_size": 10,
|
||||
"word_emb_dim": 100,
|
||||
"rnn_hidden_units": 100,
|
||||
"num_classes": 5
|
||||
}
|
||||
|
||||
infer_data = [
|
||||
['a', 'b', 'c', 'd', 'e'],
|
||||
['a', '@', 'c', 'd', 'e'],
|
||||
['a', 'b', '#', 'd', 'e'],
|
||||
['a', 'b', 'c', '?', 'e'],
|
||||
['a', 'b', 'c', 'd', '$'],
|
||||
['!', 'b', 'c', 'd', 'e']
|
||||
]
|
||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
|
||||
|
||||
os.system("mkdir save")
|
||||
save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
|
||||
save_pickle(vocab, "./save/", "word2id.pkl")
|
||||
|
||||
model = SeqLabeling(model_args)
|
||||
predictor = Predictor("./save/", task="seq_label")
|
||||
|
||||
results = predictor.predict(network=model, data=infer_data)
|
||||
|
||||
self.assertTrue(isinstance(results, list))
|
||||
self.assertGreater(len(results), 0)
|
||||
for res in results:
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEqual(len(res), 5)
|
||||
self.assertTrue(isinstance(res[0], str))
|
||||
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
||||
|
||||
class TestPredictor2(unittest.TestCase):
|
||||
def test_text_classify(self):
|
||||
# TODO
|
||||
pass
|
@ -1,24 +1,25 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.preprocess import SeqLabelPreprocess
|
||||
|
||||
data = [
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
]
|
||||
|
||||
class TestSeqLabelPreprocess(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
data = [
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
|
||||
[['Hello', 'world', '!'], ['a', 'n', '.']],
|
||||
]
|
||||
|
||||
class TestCase1(unittest.TestCase):
|
||||
def test(self):
|
||||
if os.path.exists("./save"):
|
||||
for root, dirs, files in os.walk("./save", topdown=False):
|
||||
for name in files:
|
||||
@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase):
|
||||
os.rmdir(os.path.join(root, name))
|
||||
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
|
||||
pickle_path="./save")
|
||||
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
|
||||
pickle_path="./save")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(type(result[0]), DataSet)
|
||||
self.assertEqual(type(result[1]), DataSet)
|
||||
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
||||
|
||||
class TestCase2(unittest.TestCase):
|
||||
def test(self):
|
||||
if os.path.exists("./save"):
|
||||
for root, dirs, files in os.walk("./save", topdown=False):
|
||||
for name in files:
|
||||
os.remove(os.path.join(root, name))
|
||||
for name in dirs:
|
||||
os.rmdir(os.path.join(root, name))
|
||||
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
|
||||
pickle_path="./save", train_dev_split=0.4,
|
||||
cross_val=True)
|
||||
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
|
||||
pickle_path="./save", train_dev_split=0.4,
|
||||
cross_val=True)
|
||||
cross_val=False)
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(type(result[0]), DataSet)
|
||||
self.assertEqual(type(result[1]), DataSet)
|
||||
self.assertEqual(type(result[2]), DataSet)
|
||||
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
||||
|
||||
class TestCase3(unittest.TestCase):
|
||||
def test(self):
|
||||
num_folds = 2
|
||||
result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
|
||||
pickle_path="./save", train_dev_split=0.4,
|
||||
cross_val=True, n_fold=num_folds)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(len(result[0]), num_folds)
|
||||
self.assertEqual(len(result[1]), num_folds)
|
||||
for data_set in result[0] + result[1]:
|
||||
self.assertEqual(type(data_set), DataSet)
|
||||
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
@ -1,37 +1,55 @@
|
||||
from fastNLP.core.preprocess import SeqLabelPreprocess
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.field import TextField
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
|
||||
data_name = "pku_training.utf8"
|
||||
pickle_path = "data_for_tests"
|
||||
|
||||
|
||||
def foo():
|
||||
loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
|
||||
train_data = loader.load_pku()
|
||||
class TestTester(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
model_args = {
|
||||
"vocab_size": 10,
|
||||
"word_emb_dim": 100,
|
||||
"rnn_hidden_units": 100,
|
||||
"num_classes": 5
|
||||
}
|
||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
|
||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/",
|
||||
"use_cuda": False, "print_every_step": 1}
|
||||
|
||||
train_args = ConfigSection()
|
||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})
|
||||
train_data = [
|
||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
]
|
||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
|
||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
|
||||
|
||||
# Preprocessor
|
||||
p = SeqLabelPreprocess()
|
||||
train_data = p.run(train_data)
|
||||
train_args["vocab_size"] = p.vocab_size
|
||||
train_args["num_classes"] = p.num_classes
|
||||
data_set = DataSet()
|
||||
for example in train_data:
|
||||
text, label = example[0], example[1]
|
||||
x = TextField(text, False)
|
||||
y = TextField(label, is_target=True)
|
||||
ins = Instance(word_seq=x, label_seq=y)
|
||||
data_set.append(ins)
|
||||
|
||||
model = SeqLabeling(train_args)
|
||||
data_set.index_field("word_seq", vocab)
|
||||
data_set.index_field("label_seq", label_vocab)
|
||||
|
||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
|
||||
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
|
||||
"use_cuda": True}
|
||||
validator = SeqLabelTester(**valid_args)
|
||||
model = SeqLabeling(model_args)
|
||||
|
||||
print("start validation.")
|
||||
validator.test(model, train_data)
|
||||
print(validator.show_metrics())
|
||||
tester = SeqLabelTester(**valid_args)
|
||||
tester.test(network=model, dev_data=data_set)
|
||||
# If this can run, everything is OK.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
foo()
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
@ -1,33 +1,54 @@
|
||||
import os
|
||||
|
||||
import torch.nn as nn
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.field import TextField
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.loss import Loss
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
|
||||
|
||||
class TestTrainer(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
|
||||
args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
|
||||
"save_best_dev": True, "model_name": "default_model_name.pkl",
|
||||
"loss": Loss(None),
|
||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
|
||||
"vocab_size": 20,
|
||||
"vocab_size": 10,
|
||||
"word_emb_dim": 100,
|
||||
"rnn_hidden_units": 100,
|
||||
"num_classes": 3
|
||||
"num_classes": 5
|
||||
}
|
||||
trainer = SeqLabelTrainer()
|
||||
trainer = SeqLabelTrainer(**args)
|
||||
|
||||
train_data = [
|
||||
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
|
||||
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
|
||||
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
|
||||
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
|
||||
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
|
||||
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
|
||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
|
||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||
]
|
||||
dev_data = train_data
|
||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
|
||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
|
||||
|
||||
data_set = DataSet()
|
||||
for example in train_data:
|
||||
text, label = example[0], example[1]
|
||||
x = TextField(text, False)
|
||||
y = TextField(label, is_target=True)
|
||||
ins = Instance(word_seq=x, label_seq=y)
|
||||
data_set.append(ins)
|
||||
|
||||
data_set.index_field("word_seq", vocab)
|
||||
data_set.index_field("label_seq", label_vocab)
|
||||
|
||||
model = SeqLabeling(args)
|
||||
trainer.train(network=model, train_data=train_data, dev_data=dev_data)
|
||||
|
||||
trainer.train(network=model, train_data=data_set, dev_data=data_set)
|
||||
# If this can run, everything is OK.
|
||||
|
||||
os.system("rm -rf save")
|
||||
print("pickle path deleted")
|
||||
|
@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
|
||||
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt",
|
||||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
|
||||
help="path to the training data")
|
||||
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
|
||||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
|
||||
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
|
||||
parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt",
|
||||
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
|
||||
help="data used for inference")
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -86,7 +86,7 @@ def train_and_test():
|
||||
trainer = SeqLabelTrainer(
|
||||
epochs=trainer_args["epochs"],
|
||||
batch_size=trainer_args["batch_size"],
|
||||
validate=trainer_args["validate"],
|
||||
validate=False,
|
||||
use_cuda=trainer_args["use_cuda"],
|
||||
pickle_path=pickle_path,
|
||||
save_best_dev=trainer_args["save_best_dev"],
|
||||
@ -121,7 +121,7 @@ def train_and_test():
|
||||
|
||||
# Tester
|
||||
tester = SeqLabelTester(save_output=False,
|
||||
save_loss=False,
|
||||
save_loss=True,
|
||||
save_best_dev=False,
|
||||
batch_size=4,
|
||||
use_cuda=False,
|
||||
@ -139,5 +139,5 @@ def train_and_test():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# train_and_test()
|
||||
infer()
|
||||
train_and_test()
|
||||
# infer()
|
||||
|
@ -1,8 +0,0 @@
|
||||
|
||||
|
||||
def test_charlm():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_charlm()
|
85
test/model/test_seq_label.py
Normal file
85
test/model/test_seq_label.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.preprocess import SeqLabelPreprocess
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import POSDatasetLoader
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
pickle_path = "./seq_label/"
|
||||
model_name = "seq_label_model.pkl"
|
||||
config_dir = "test/data_for_tests/config"
|
||||
data_path = "test/data_for_tests/people.txt"
|
||||
data_infer_path = "test/data_for_tests/people_infer.txt"
|
||||
|
||||
|
||||
def test_training():
|
||||
# Config Loader
|
||||
trainer_args = ConfigSection()
|
||||
model_args = ConfigSection()
|
||||
ConfigLoader("_").load_config(config_dir, {
|
||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
|
||||
|
||||
# Data Loader
|
||||
pos_loader = POSDatasetLoader(data_path)
|
||||
train_data = pos_loader.load_lines()
|
||||
|
||||
# Preprocessor
|
||||
p = SeqLabelPreprocess()
|
||||
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
|
||||
model_args["vocab_size"] = p.vocab_size
|
||||
model_args["num_classes"] = p.num_classes
|
||||
|
||||
trainer = SeqLabelTrainer(
|
||||
epochs=trainer_args["epochs"],
|
||||
batch_size=trainer_args["batch_size"],
|
||||
validate=False,
|
||||
use_cuda=False,
|
||||
pickle_path=pickle_path,
|
||||
save_best_dev=trainer_args["save_best_dev"],
|
||||
model_name=model_name,
|
||||
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
|
||||
)
|
||||
|
||||
# Model
|
||||
model = SeqLabeling(model_args)
|
||||
|
||||
# Start training
|
||||
trainer.train(model, data_train, data_dev)
|
||||
|
||||
# Saver
|
||||
saver = ModelSaver(os.path.join(pickle_path, model_name))
|
||||
saver.save_pytorch(model)
|
||||
|
||||
del model, trainer, pos_loader
|
||||
|
||||
# Define the same model
|
||||
model = SeqLabeling(model_args)
|
||||
|
||||
# Dump trained parameters into the model
|
||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
|
||||
|
||||
# Load test configuration
|
||||
tester_args = ConfigSection()
|
||||
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
|
||||
|
||||
# Tester
|
||||
tester = SeqLabelTester(save_output=False,
|
||||
save_loss=True,
|
||||
save_best_dev=False,
|
||||
batch_size=4,
|
||||
use_cuda=False,
|
||||
pickle_path=pickle_path,
|
||||
model_name="seq_label_in_test.pkl",
|
||||
print_every_step=1
|
||||
)
|
||||
|
||||
# Start testing with validation data
|
||||
tester.test(model, data_dev)
|
||||
|
||||
loss, accuracy = tester.metrics
|
||||
assert 0 < accuracy < 1
|
@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
|
||||
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
|
||||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
|
||||
help="path to the training data")
|
||||
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
|
||||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
|
||||
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -115,4 +115,4 @@ def train():
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
infer()
|
||||
# infer()
|
||||
|
Loading…
Reference in New Issue
Block a user