Merge pull request #76 from fastnlp/add_field_support

Introduce Field & DataSet to eliminate sub-trainers & sub-testers
This commit is contained in:
Xipeng Qiu 2018-09-17 20:35:36 +08:00 committed by GitHub
commit b46c4ba042
22 changed files with 1079 additions and 732 deletions

fastNLP/core/action.py

@ -4,88 +4,6 @@ import numpy as np
import torch
class Action(object):
"""Operations shared by Trainer, Tester, or Inference.
This is designed to reduce duplicated code.
- make_batch: produce a mini-batch of data. @staticmethod
- pad: padding method used in sequence modeling. @staticmethod
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
"""
def __init__(self):
super(Action, self).__init__()
@staticmethod
def make_batch(iterator, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:return :
if output_length is True,
(batch_x, seq_len): tuple of two elements
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
if output_length is False,
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
for batch in iterator:
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = Action.pad(batch_x)
# pad batch_y only if it is a 2-level list
if len(batch_y) > 0 and isinstance(batch_y[0], list):
batch_y = Action.pad(batch_y)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
batch_y = convert_to_torch_tensor(batch_y, use_cuda)
# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]
if output_length:
seq_len = [len(x) for x in batch_x]
yield (batch_x, seq_len), batch_y
else:
yield batch_x, batch_y
@staticmethod
def pad(batch, fill=0):
""" Pad a mini-batch of sequence samples to maximum length of this batch.
:param batch: list of list
:param fill: word index to pad, default 0.
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
@staticmethod
def mode(model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
"""
if is_test:
model.eval()
else:
model.train()
def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.
@ -168,19 +86,7 @@ class BaseSampler(object):
"""
def __init__(self, data_set):
"""
:param data_set: multi-level list, of shape [num_example, *]
"""
self.data_set_length = len(data_set)
self.data = data_set
def __len__(self):
return self.data_set_length
def __iter__(self):
def __call__(self, *args, **kwargs):
raise NotImplementedError
@ -189,16 +95,8 @@ class SequentialSampler(BaseSampler):
"""
def __init__(self, data_set):
"""
:param data_set: multi-level list
"""
super(SequentialSampler, self).__init__(data_set)
def __iter__(self):
return iter(self.data)
def __call__(self, data_set):
return list(range(len(data_set)))
class RandomSampler(BaseSampler):
@ -206,17 +104,9 @@ class RandomSampler(BaseSampler):
"""
def __init__(self, data_set):
"""
def __call__(self, data_set):
return list(np.random.permutation(len(data_set)))
:param data_set: multi-level list
"""
super(RandomSampler, self).__init__(data_set)
self.order = np.random.permutation(self.data_set_length)
def __iter__(self):
return iter((self.data[idx] for idx in self.order))
class Batchifier(object):
@ -252,6 +142,7 @@ class BucketBatchifier(Batchifier):
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
In sampling, first random choose a bucket. Then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths.
TODO: merge it into Batch
"""
def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):

fastNLP/core/batch.py (new file, 126 lines)

@ -0,0 +1,126 @@
from collections import defaultdict
import torch
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.
::
for batch_x, batch_y in Batch(data_set):
"""
def __init__(self, dataset, batch_size, sampler, use_cuda):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
self.use_cuda = use_cuda
self.idx_list = None
self.curidx = 0
def __iter__(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()
return self
def __next__(self):
"""
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
batch_x also contains an item (str: list of int) about origin lengths,
which means ("field_name_origin_len": origin lengths).
E.g.
::
{'text': tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
The names of fields are defined in preprocessor's convert_to_dataset method.
"""
if self.curidx >= len(self.idx_list):
raise StopIteration
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
for field_name, field_length in self.lengths.items()}
origin_lengths = {field_name: field_length[self.curidx: endidx]
for field_name, field_length in self.lengths.items()}
batch_x, batch_y = defaultdict(list), defaultdict(list)
for idx in range(self.curidx, endidx):
x, y = self.dataset.to_tensor(idx, padding_length)
for name, tensor in x.items():
batch_x[name].append(tensor)
for name, tensor in y.items():
batch_y[name].append(tensor)
batch_origin_length = {}
# combine instances into a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
batch[name] = torch.stack(tensor_list, dim=0).cuda()
else:
batch[name] = torch.stack(tensor_list, dim=0)
# add origin lengths in batch_x
for name, tensor in batch_x.items():
if self.use_cuda:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
else:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
batch_x.update(batch_origin_length)
self.curidx = endidx
return batch_x, batch_y
if __name__ == "__main__":
"""simple running example
"""
texts = ["i am a cat",
"this is a test of new batch",
"haha"
]
labels = [0, 1, 0]
# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text.split():
if tokens not in vocab:
vocab[tokens] = len(vocab)
print("vocabulary: ", vocab)
# prepare input dataset
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text.split(), False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)
# use vocabulary to index data
data.index_field("text", vocab)
# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))
# use batch to iterate dataset
data_iterator = Batch(data, 2, SeqSampler(), False)
for epoch in range(1):
for batch_x, batch_y in data_iterator:
print(batch_x)
print(batch_y)
# do stuff
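# With the toy data above and batch_size=2, the first yielded pair is expected
# to look like the docstring example (indices follow the vocabulary printed above):
#   batch_x = {'text': tensor([[0, 1, 2, 3, 0, 0, 0],
#                               [4, 5, 2, 6, 7, 8, 9]]),
#              'text_origin_len': tensor([4, 7])}
#   batch_y = {'label': tensor([[0], [1]])}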

fastNLP/core/dataset.py (new file, 111 lines)

@ -0,0 +1,111 @@
from collections import defaultdict
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
if has_target is True:
if label_vocab is None:
raise RuntimeError("Must provide label vocabulary to transform labels.")
return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
else:
return create_unlabeled_dataset_from_lists(str_lists, word_vocab)
def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
"""Create an DataSet instance that contains labels.
:param str_lists: list of list of strings, [num_examples, 2, *].
::
[
[[word_11, word_12, ...], [label_11, label_12, ...]],
...
]
:param word_vocab: dict of (str: int), which means (word: index).
:param label_vocab: dict of (str: int), which means (label: index).
:return data_set: a DataSet instance.
"""
data_set = DataSet()
for example in str_lists:
word_seq, label_seq = example[0], example[1]
x = TextField(word_seq, is_target=False)
y = TextField(label_seq, is_target=True)
data_set.append(Instance(word_seq=x, label_seq=y))
data_set.index_field("word_seq", word_vocab)
data_set.index_field("label_seq", label_vocab)
return data_set
def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
"""Create an DataSet instance that contains no labels.
:param str_lists: list of list of strings, [num_examples, *].
::
[
[word_11, word_12, ...],
...
]
:param word_vocab: dict of (str: int), which means (word: index).
:return data_set: a DataSet instance.
"""
data_set = DataSet()
for word_seq in str_lists:
x = TextField(word_seq, is_target=False)
data_set.append(Instance(word_seq=x))
data_set.index_field("word_seq", word_vocab)
return data_set
class DataSet(list):
"""A DataSet object is a list of Instance objects.
"""
def __init__(self, name="", instances=None):
"""
:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)
"""
list.__init__([])
self.name = name
if instances is not None:
self.extend(instances)
def index_all(self, vocab):
for ins in self:
ins.index_all(vocab)
def index_field(self, field_name, vocab):
for ins in self:
ins.index_field(field_name, vocab)
def to_tensor(self, idx: int, padding_length: dict):
"""Convert an instance in a dataset to tensor.
:param idx: int, the index of the instance in the dataset.
:param padding_length: dict of (str: int), which means (field name: padding length of that field).
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
"""
ins = self[idx]
return ins.to_tensor(padding_length)
def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.
:return lengths: dict of (str: list). The str is the field name.
The list contains lengths of this field in all instances.
"""
lengths = defaultdict(list)
for ins in self:
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths
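A minimal usage sketch of the helpers above, with a toy vocabulary (the padding lengths passed to to_tensor are arbitrary):

from fastNLP.core.dataset import create_dataset_from_lists

word_vocab = {"<pad>": 0, "<unk>": 1, "Hello": 2, "world": 3, "!": 4}
label_vocab = {"<pad>": 0, "<unk>": 1, "a": 2, "n": 3, ".": 4}

examples = [[["Hello", "world", "!"], ["a", "n", "."]]]
data_set = create_dataset_from_lists(examples, word_vocab,
                                     has_target=True, label_vocab=label_vocab)
print(data_set.get_length())   # field lengths: word_seq -> [3], label_seq -> [3]
print(data_set.to_tensor(0, {"word_seq": 5, "label_seq": 5}))  # (inputs, targets) padded to length 5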

fastNLP/core/field.py (new file, 93 lines)

@ -0,0 +1,93 @@
import torch
class Field(object):
"""A field defines a data type.
"""
def __init__(self, is_target: bool):
self.is_target = is_target
def index(self, vocab):
raise NotImplementedError
def get_length(self):
raise NotImplementedError
def to_tensor(self, padding_length):
raise NotImplementedError
class TextField(Field):
def __init__(self, text, is_target):
"""
:param text: list of strings
:param is_target: bool
"""
super(TextField, self).__init__(is_target)
self.text = text
self._index = None
def index(self, vocab):
if self._index is None:
self._index = [vocab[c] for c in self.text]
else:
raise RuntimeError("Replicate indexing of this field.")
return self._index
def get_length(self):
"""Fetch the length of the text field.
:return length: int, the length of the text.
"""
return len(self.text)
def to_tensor(self, padding_length: int):
"""Convert text field to tensor.
:param padding_length: int
:return tensor: torch.LongTensor, of shape [padding_length, ]
"""
pads = []
if self._index is None:
raise RuntimeError("Indexing not done before to_tensor in TextField.")
if padding_length > self.get_length():
pads = [0] * (padding_length - self.get_length())
return torch.LongTensor(self._index + pads)
class LabelField(Field):
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
self._index = None
def get_length(self):
"""Fetch the length of the label field.
:return length: int, the length of the label, always 1.
"""
return 1
def index(self, vocab):
if self._index is None:
self._index = vocab[self.label]
return self._index
def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
return torch.LongTensor([self.label])
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor([self._index])
if __name__ == "__main__":
tf = TextField("test the code".split(), is_target=False)
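A small sketch of the Field API above: a TextField is indexed against a toy vocabulary and padded, while a LabelField yields a single-element tensor (the vocabulary contents here are illustrative):

from fastNLP.core.field import TextField, LabelField

vocab = {"test": 0, "the": 1, "code": 2}
text = TextField("test the code".split(), is_target=False)
text.index(vocab)
print(text.to_tensor(padding_length=5))   # tensor([0, 1, 2, 0, 0])

label = LabelField("positive", is_target=True)
label.index({"positive": 1, "negative": 0})
print(label.to_tensor(padding_length=1))  # tensor([1])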

fastNLP/core/instance.py (new file, 53 lines)

@ -0,0 +1,53 @@
class Instance(object):
"""An instance which consists of Fields is an example in the DataSet.
"""
def __init__(self, **fields):
self.fields = fields
self.has_index = False
self.indexes = {}
def add_field(self, field_name, field):
self.fields[field_name] = field
def get_length(self):
"""Fetch the length of all fields in the instance.
:return length: dict of (str: int), which means (field name: field length).
"""
length = {name: field.get_length() for name, field in self.fields.items()}
return length
def index_field(self, field_name, vocab):
"""use `vocab` to index certain field
"""
self.indexes[field_name] = self.fields[field_name].index(vocab)
def index_all(self, vocab):
"""use `vocab` to index all fields
"""
if self.has_index:
print("error")
return self.indexes
indexes = {name: field.index(vocab) for name, field in self.fields.items()}
self.indexes = indexes
return indexes
def to_tensor(self, padding_length: dict):
"""Convert instance to tensor.
:param padding_length: dict of (str: int), which means (field name: padding_length of this field)
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
If is_target is False for all fields, tensor_y would be an empty dict.
"""
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
if field.is_target:
tensor_y[name] = field.to_tensor(padding_length[name])
else:
tensor_x[name] = field.to_tensor(padding_length[name])
return tensor_x, tensor_y
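A brief sketch of how an Instance splits input and target fields when converted to tensors (field names and the toy vocabulary are illustrative):

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

word_vocab = {"new": 0, "batch": 1}
ins = Instance(word_seq=TextField(["new", "batch"], is_target=False),
               label=LabelField(1, is_target=True))
ins.index_field("word_seq", word_vocab)
print(ins.get_length())                    # {'word_seq': 2, 'label': 1}
x, y = ins.to_tensor({"word_seq": 4, "label": 1})
print(x)                                   # {'word_seq': tensor([0, 1, 0, 0])}
print(y)                                   # {'label': tensor([1])}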

fastNLP/core/predictor.py

@ -1,53 +1,10 @@
import numpy as np
import torch
from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils
def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
"""Batch and Pad data, only for Inference.
:param iterator: An iterable object that returns a list of indices representing a mini-batch of samples.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
:return:
"""
for batch_x in iterator:
batch_x = pad(batch_x)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]
if min_len is not None and batch_x.size(1) < min_len:
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
batch_x = torch.cat((batch_x, pad_tensor), 1)
if output_length:
seq_len = [len(x) for x in batch_x]
yield tuple([batch_x, seq_len])
else:
yield batch_x
def pad(batch, fill=0):
""" Pad a mini-batch of sequence samples to maximum length of this batch.
:param batch: list of list
:param fill: word index to pad, default 0.
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
from fastNLP.core.action import SequentialSampler
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import create_dataset_from_lists
from fastNLP.core.preprocess import load_pickle
class Predictor(object):
@ -59,11 +16,17 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""
def __init__(self, pickle_path):
def __init__(self, pickle_path, task):
"""
:param pickle_path: str, the path to the pickle files.
:param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
"""
self.batch_size = 1
self.batch_output = []
self.iterator = None
self.pickle_path = pickle_path
self._task = task # one of ("seq_label", "text_classify")
self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
@ -71,19 +34,19 @@ class Predictor(object):
"""Perform inference using the trained model.
:param network: a PyTorch model (cpu)
:param data: list of list of strings
:param data: list of list of strings, [num_examples, seq_len]
:return: list of list of strings, [num_examples, tag_seq_length]
"""
# transform strings into indices
# transform strings into DataSet object
data = self.prepare_input(data)
# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.batch_output.clear()
data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
for batch_x in self.make_batch(data_iterator, use_cuda=False):
for batch_x, _ in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
@ -99,103 +62,61 @@ class Predictor(object):
def data_forward(self, network, x):
"""Forward through network."""
raise NotImplementedError
def make_batch(self, iterator, use_cuda):
raise NotImplementedError
y = network(**x)
if self._task == "seq_label":
y = network.prediction(y)
return y
def prepare_input(self, data):
"""Transform two-level list of strings into that of index.
"""Transform two-level list of strings into an DataSet object.
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.
:param data:
:param data: list of list of strings.
::
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
:return data_index: list of list of int.
:return data_set: a DataSet instance.
"""
assert isinstance(data, list)
data_index = []
default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
for example in data:
data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
return data_index
return create_dataset_from_lists(data, self.word2index, has_target=False)
def prepare_output(self, data):
"""Transform list of batch outputs into strings."""
raise NotImplementedError
if self._task == "seq_label":
return self._seq_label_prepare_output(data)
elif self._task == "text_classify":
return self._text_classify_prepare_output(data)
else:
raise NotImplementedError("Unknown task type {}".format(self._task))
class SeqLabelInfer(Predictor):
"""
Inference on sequence labeling models.
"""
def __init__(self, pickle_path):
super(SeqLabelInfer, self).__init__(pickle_path)
def data_forward(self, network, inputs):
"""
This is only for sequence labeling with CRF decoder.
:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return prediction: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
y = network(x)
prediction = network.prediction(y, mask)
return torch.Tensor(prediction)
def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=True)
def prepare_output(self, batch_outputs):
"""Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length].
:return results: 2-D list of strings, shape [num_examples, tag_seq_length]
"""
def _seq_label_prepare_output(self, batch_outputs):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results
class ClassificationInfer(Predictor):
"""
Inference on Classification models.
"""
def __init__(self, pickle_path):
super(ClassificationInfer, self).__init__(pickle_path)
def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits
def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=False, min_len=5)
def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
:return results: list of strings
"""
def _text_classify_prepare_output(self, batch_outputs):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([self.index2label[i] for i in idx])
return results
class SeqLabelInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
class ClassificationInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
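A usage sketch of the task-driven Predictor, following the same pattern as the predictor test further down; it assumes word2id.pkl and id2class.pkl already exist under pickle_path and that the pickled vocabulary covers the input tokens:

from fastNLP.core.predictor import Predictor
from fastNLP.models.sequence_modeling import SeqLabeling

# assumes "./save/" already holds word2id.pkl and id2class.pkl, and that the
# pickled vocabulary covers the tokens passed in `data`
model = SeqLabeling({"vocab_size": 10, "word_emb_dim": 100,
                     "rnn_hidden_units": 100, "num_classes": 5})
predictor = Predictor("./save/", task="seq_label")
results = predictor.predict(network=model, data=[["a", "b", "c"]])
print(results)  # one list of tag strings per input sequence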

fastNLP/core/preprocess.py

@ -3,6 +3,10 @@ import os
import numpy as np
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
@ -84,7 +88,7 @@ class BasePreprocess(object):
return len(self.label2index)
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
"""Main preprocessing pipeline.
"""Main pre-processing pipeline.
:param train_dev_data: three-level list, with either single label or multiple labels in a sample.
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
@ -92,7 +96,9 @@ class BasePreprocess(object):
:param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set.
:param cross_val: bool, whether to do cross validation.
:param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
:return results: a tuple of datasets after preprocessing.
:return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
"""
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
@ -111,68 +117,87 @@ class BasePreprocess(object):
index2label = self.build_reverse_dict(self.label2index)
save_pickle(index2label, pickle_path, "id2class.pkl")
data_train = []
data_dev = []
train_set = []
dev_set = []
if not cross_val:
if not pickle_exist(pickle_path, "data_train.pkl"):
data_train.extend(self.to_index(train_dev_data))
if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
split = int(len(data_train) * train_dev_split)
data_dev = data_train[: split]
data_train = data_train[split:]
save_pickle(data_dev, pickle_path, "data_dev.pkl")
split = int(len(train_dev_data) * train_dev_split)
data_dev = train_dev_data[: split]
data_train = train_dev_data[split:]
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
save_pickle(dev_set, pickle_path, "data_dev.pkl")
print("{} of the training data is split for validation. ".format(train_dev_split))
save_pickle(data_train, pickle_path, "data_train.pkl")
else:
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
save_pickle(train_set, pickle_path, "data_train.pkl")
else:
data_train = load_pickle(pickle_path, "data_train.pkl")
train_set = load_pickle(pickle_path, "data_train.pkl")
if pickle_exist(pickle_path, "data_dev.pkl"):
data_dev = load_pickle(pickle_path, "data_dev.pkl")
dev_set = load_pickle(pickle_path, "data_dev.pkl")
else:
# cross_val is True
if not pickle_exist(pickle_path, "data_train_0.pkl"):
# cross validation
data_idx = self.to_index(train_dev_data)
data_cv = self.cv_split(data_idx, n_fold)
data_cv = self.cv_split(train_dev_data, n_fold)
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
save_pickle(
data_train_cv, pickle_path,
"data_train_{}.pkl".format(i))
save_pickle(
data_dev_cv, pickle_path,
"data_dev_{}.pkl".format(i))
data_train.append(data_train_cv)
data_dev.append(data_dev_cv)
train_set.append(data_train_cv)
dev_set.append(data_dev_cv)
print("{}-fold cross validation.".format(n_fold))
else:
for i in range(n_fold):
data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i))
data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i))
data_train.append(data_train_cv)
data_dev.append(data_dev_cv)
train_set.append(data_train_cv)
dev_set.append(data_dev_cv)
# prepare test data if provided
data_test = []
test_set = []
if test_data is not None:
if not pickle_exist(pickle_path, "data_test.pkl"):
data_test = self.to_index(test_data)
save_pickle(data_test, pickle_path, "data_test.pkl")
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
save_pickle(test_set, pickle_path, "data_test.pkl")
# return preprocessed results
results = [data_train]
results = [train_set]
if cross_val or train_dev_split > 0:
results.append(data_dev)
results.append(dev_set)
if test_data:
results.append(data_test)
results.append(test_set)
if len(results) == 1:
return results[0]
else:
return tuple(results)
def build_dict(self, data):
raise NotImplementedError
label2index = DEFAULT_WORD_TO_INDEX.copy()
word2index = DEFAULT_WORD_TO_INDEX.copy()
for example in data:
for word in example[0]:
if word not in word2index:
word2index[word] = len(word2index)
label = example[1]
if isinstance(label, str):
# label is a string
if label not in label2index:
label2index[label] = len(label2index)
elif isinstance(label, list):
# label is a list of strings
for single_label in label:
if single_label not in label2index:
label2index[single_label] = len(label2index)
return word2index, label2index
def to_index(self, data):
raise NotImplementedError
def build_reverse_dict(self, word_dict):
id2word = {word_dict[w]: w for w in word_dict}
@ -186,11 +211,23 @@ class BasePreprocess(object):
return data_train, data_dev
def cv_split(self, data, n_fold):
"""Split data for cross validation."""
"""Split data for cross validation.
:param data: list of examples
:param n_fold: int
:return data_cv:
::
[
(data_train, data_dev), # 1st fold
(data_train, data_dev), # 2nd fold
...
]
"""
data_copy = data.copy()
np.random.shuffle(data_copy)
fold_size = round(len(data_copy) / n_fold)
data_cv = []
for i in range(n_fold - 1):
start = i * fold_size
@ -202,154 +239,72 @@ class BasePreprocess(object):
data_dev = data_copy[start:]
data_train = data_copy[:start]
data_cv.append((data_train, data_dev))
return data_cv
def convert_to_dataset(self, data, vocab, label_vocab):
"""Convert list of indices into a DataSet object.
:param data: list. Entries are strings.
:param vocab: a dict, mapping string (token) to index (int).
:param label_vocab: a dict, mapping string (label) to index (int).
:return data_set: a DataSet object
"""
use_word_seq = False
use_label_seq = False
use_label_str = False
# construct a DataSet object and fill it with Instances
data_set = DataSet()
for example in data:
words, label = example[0], example[1]
instance = Instance()
if isinstance(words, list):
x = TextField(words, is_target=False)
instance.add_field("word_seq", x)
use_word_seq = True
else:
raise NotImplementedError("words is a {}".format(type(words)))
if isinstance(label, list):
y = TextField(label, is_target=True)
instance.add_field("label_seq", y)
use_label_seq = True
elif isinstance(label, str):
y = LabelField(label, is_target=True)
instance.add_field("label", y)
use_label_str = True
else:
raise NotImplementedError("label is a {}".format(type(label)))
data_set.append(instance)
# convert strings to indices
if use_word_seq:
data_set.index_field("word_seq", vocab)
if use_label_seq:
data_set.index_field("label_seq", label_vocab)
if use_label_str:
data_set.index_field("label", label_vocab)
return data_set
class SeqLabelPreprocess(BasePreprocess):
"""Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
Designed for data given as a three-level list with multiple labels in each sample.
::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
"""
def __init__(self):
super(SeqLabelPreprocess, self).__init__()
def build_dict(self, data):
"""Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
:param data: three-level list
::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return word2index: dict of {str, int}
label2index: dict of {str, int}
"""
# In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch.
label2index = DEFAULT_WORD_TO_INDEX.copy()
word2index = DEFAULT_WORD_TO_INDEX.copy()
for example in data:
for word, label in zip(example[0], example[1]):
if word not in word2index:
word2index[word] = len(word2index)
if label not in label2index:
label2index[label] = len(label2index)
return word2index, label2index
def to_index(self, data):
"""Convert word strings and label strings into indices.
:param data: three-level list
::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return data_index: the same shape as data, but each string is replaced by its corresponding index
"""
data_index = []
for example in data:
word_list = []
label_list = []
for word, label in zip(example[0], example[1]):
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
data_index.append([word_list, label_list])
return data_index
class ClassPreprocess(BasePreprocess):
""" Preprocess pipeline for classification datasets.
Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
Designed for data given as a three-level list with a single label in each sample.
::
[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
"""
def __init__(self):
super(ClassPreprocess, self).__init__()
def build_dict(self, data):
"""Build vocabulary."""
# build vocabulary from scratch if nothing exists
word2index = DEFAULT_WORD_TO_INDEX.copy()
label2index = {} # DEFAULT_WORD_TO_INDEX.copy()
# collect every word and label
for sent, label in data:
if len(sent) <= 1:
continue
if label not in label2index:
label2index[label] = len(label2index)
for word in sent:
if word not in word2index:
word2index[word] = len(word2index)
return word2index, label2index
def to_index(self, data):
"""Convert word strings and label strings into indices.
:param data: three-level list
::
[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
:return data_index: the same shape as data, but each string is replaced by its corresponding index
"""
data_index = []
for example in data:
word_list = []
# example[0] is the word list, example[1] is the single label
for word in example[0]:
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
data_index.append([word_list, label_index])
return data_index
def infer_preprocess(pickle_path, data):
"""Preprocess over inference data. Transform three-level list of strings into that of index.
::
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
"""
word2index = load_pickle(pickle_path, "word2id.pkl")
data_index = []
for example in data:
data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example])
return data_index
if __name__ == "__main__":
p = BasePreprocess()
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
[["You", "are", "pretty", "."], "1"]
]
training_set = p.run(train_dev_data)
print(training_set)
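Since the preprocessor now returns DataSet objects, its output can be fed directly into the new Batch iterator; a rough sketch (paths, batch size and sampler choice are arbitrary, and the pickle_path directory is assumed to exist):

from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.core.preprocess import SeqLabelPreprocess

train_dev_data = [[["Hello", "world", "!"], ["a", "n", "."]],
                  [["Tom", "and", "Jerry", "."], ["n", "&", "n", "."]]]
train_set = SeqLabelPreprocess().run(train_dev_data, pickle_path="./save/")
for batch_x, batch_y in Batch(train_set, batch_size=2,
                              sampler=RandomSampler(), use_cuda=False):
    print(batch_x["word_seq"].shape, batch_y["label_seq"].shape)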

fastNLP/core/tester.py

@ -1,9 +1,8 @@
import numpy as np
import torch
from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.modules import utils
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.saver.logger import create_logger
logger = create_logger(__name__, "./train_test.log")
@ -35,16 +34,16 @@ class BaseTester(object):
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Obviously, "required_args" is the subset of "default_args".
The value in "default_args" to the keys in "required_args" is simply for type check.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
# add required arguments here
required_args = {}
required_args = {"task" # one of ("seq_label", "text_classify")
}
for req_key in required_args:
if req_key not in kwargs:
logger.error("Tester lacks argument {}".format(req_key))
raise ValueError("Tester lacks argument {}".format(req_key))
self._task = kwargs["task"]
for key in default_args:
if key in kwargs:
@ -79,14 +78,14 @@ class BaseTester(object):
self._model = network
# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.mode(network, is_test=True)
self.eval_history.clear()
self.batch_output.clear()
iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False))
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
step = 0
for batch_x, batch_y in self.make_batch(iterator):
for batch_x, batch_y in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)
@ -102,17 +101,22 @@ class BaseTester(object):
print(self.make_eval_output(prediction, eval_results))
step += 1
def mode(self, model, test):
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
:param model: a PyTorch model
:param test: bool, whether in test mode.
:param is_test: bool, whether in test mode or not.
"""
Action.mode(model, test)
if is_test:
model.eval()
else:
model.train()
def data_forward(self, network, x):
"""A forward pass of the model. """
raise NotImplementedError
y = network(**x)
return y
def evaluate(self, predict, truth):
"""Compute evaluation metrics.
@ -121,7 +125,38 @@ class BaseTester(object):
:param truth: Tensor
:return eval_results: can be anything. It will be stored in self.eval_history
"""
raise NotImplementedError
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
if self._task == "seq_label":
return self._seq_label_evaluate(predict, truth)
elif self._task == "text_classify":
return self._text_classify_evaluate(predict, truth)
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))
def _seq_label_evaluate(self, predict, truth):
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth) / batch_size
prediction = self._model.prediction(predict)
# pad prediction to equal length
for pred in prediction:
if len(pred) < max_len:
pred += [0] * (max_len - len(pred))
results = torch.Tensor(prediction).view(-1, )
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]
def _text_classify_evaluate(self, y_logit, y_true):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]
@property
def metrics(self):
@ -131,7 +166,27 @@ class BaseTester(object):
:return : variable number of outputs
"""
raise NotImplementedError
if self._task == "seq_label":
return self._seq_label_metrics
elif self._task == "text_classify":
return self._text_classify_metrics
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))
@property
def _seq_label_metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy
@property
def _text_classify_metrics(self):
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
def show_metrics(self):
"""Customize evaluation outputs in Trainer.
@ -140,10 +195,8 @@ class BaseTester(object):
:return print_str: str
"""
raise NotImplementedError
def make_batch(self, iterator):
raise NotImplementedError
loss, accuracy = self.metrics
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
def make_eval_output(self, predictions, eval_results):
"""Customize Tester outputs.
@ -152,108 +205,20 @@ class BaseTester(object):
:param eval_results: Tensor
:return: str, to be printed.
"""
raise NotImplementedError
return self.show_metrics()
class SeqLabelTester(BaseTester):
"""Tester for sequence labeling.
"""
def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
super(SeqLabelTester, self).__init__(**test_args)
self.max_len = None
self.mask = None
self.seq_len = None
def data_forward(self, network, inputs):
"""This is only for sequence labeling with CRF decoder.
:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return y: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs, tuple):
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask
self.seq_len = seq_len
y = network(x)
return y
def evaluate(self, predict, truth):
"""Compute metrics (or loss).
:param predict: Tensor, [batch_size, max_len, tag_size]
:param truth: Tensor, [batch_size, max_len]
:return:
"""
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth, self.mask) / batch_size
prediction = self._model.prediction(predict, self.mask)
results = torch.Tensor(prediction).view(-1, )
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]
def metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy
def show_metrics(self):
"""This is called by Trainer to print evaluation on dev set.
:return print_str: str
"""
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
def make_batch(self, iterator):
return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)
class ClassificationTester(BaseTester):
"""Tester for classification."""
def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method.
can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
super(ClassificationTester, self).__init__(**test_args)
def make_batch(self, iterator, max_len=None):
return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)
def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits
def evaluate(self, y_logit, y_true):
"""Return y_pred and y_true."""
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]
def metrics(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
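The text-classification branch above stores (softmax probability, gold label) pairs in eval_history and reduces them to accuracy in _text_classify_metrics; a standalone sketch of that reduction with dummy tensors:

import torch
import torch.nn.functional as F

eval_history = [
    (F.softmax(torch.tensor([[2.0, 0.1], [0.2, 1.5]]), dim=-1), torch.tensor([0, 1])),
    (F.softmax(torch.tensor([[0.3, 0.9]]), dim=-1), torch.tensor([0])),
]
y_prob, y_true = zip(*eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
print(acc)  # two of three predictions are correct -> 0.666...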

fastNLP/core/trainer.py

@ -4,15 +4,13 @@ import time
from datetime import timedelta
import torch
import tensorboardX
from tensorboardX import SummaryWriter
from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver
@ -50,16 +48,16 @@ class BaseTrainer(object):
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Obviously, "required_args" is the subset of "default_args".
The value in "default_args" to the keys in "required_args" is simply for type check.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
# add required arguments here
required_args = {}
required_args = {"task" # one of ("seq_label", "text_classify")
}
for req_key in required_args:
if req_key not in kwargs:
logger.error("Trainer lacks argument {}".format(req_key))
raise ValueError("Trainer lacks argument {}".format(req_key))
self._task = kwargs["task"]
for key in default_args:
if key in kwargs:
@ -90,13 +88,14 @@ class BaseTrainer(object):
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0
def train(self, network, train_data, dev_data=None):
"""General Training Procedure
:param network: a model
:param train_data: three-level list, the training set.
:param dev_data: three-level list, the validation data (optional)
:param train_data: a DataSet instance, the training data
:param dev_data: a DataSet instance, the validation data (optional)
"""
# transfer model to gpu if available
if torch.cuda.is_available() and self.use_cuda:
@ -126,9 +125,10 @@ class BaseTrainer(object):
logger.info("training epoch {}".format(epoch))
# turn on network training mode
self.mode(network, test=False)
self.mode(network, is_test=False)
# prepare mini-batch iterator
data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
logger.info("prepared data iterator")
# one forward and backward pass
@ -157,7 +157,7 @@ class BaseTrainer(object):
- epoch: int,
"""
step = 0
for batch_x, batch_y in self.make_batch(data_iterator):
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
@ -166,10 +166,6 @@ class BaseTrainer(object):
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
@ -204,11 +200,17 @@ class BaseTrainer(object):
network_copy = copy.deepcopy(network)
self.train(network_copy, train_data_cv[i], dev_data_cv[i])
def make_batch(self, iterator):
raise NotImplementedError
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
def mode(self, network, test):
Action.mode(network, test)
:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
"""
if is_test:
model.eval()
else:
model.train()
def define_optimizer(self):
"""Define framework-specific optimizer specified by the models.
@ -224,7 +226,20 @@ class BaseTrainer(object):
self._optimizer.step()
def data_forward(self, network, x):
raise NotImplementedError
if self._task == "seq_label":
y = network(x["word_seq"], x["word_seq_origin_len"])
elif self._task == "text_classify":
y = network(x["word_seq"])
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))
if not self._graph_summaried:
if self._task == "seq_label":
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
elif self._task == "text_classify":
self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
self._graph_summaried = True
return y
def grad_backward(self, loss):
"""Compute gradient with link rules.
@ -243,6 +258,13 @@ class BaseTrainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
truth = truth.view((-1,))
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
return self._loss_func(predict, truth)
def define_loss(self):
@ -270,7 +292,12 @@ class BaseTrainer(object):
:param validator: a Tester instance
:return: bool, True means current results on dev set is the best.
"""
raise NotImplementedError
loss, accuracy = validator.metrics
if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
else:
return False
def save_model(self, network, model_name):
"""Save this model with such a name.
@ -291,55 +318,11 @@ class SeqLabelTrainer(BaseTrainer):
"""Trainer for Sequence Labeling
"""
def __init__(self, **kwargs):
kwargs.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
super(SeqLabelTrainer, self).__init__(**kwargs)
# self.vocab_size = kwargs["vocab_size"]
# self.num_classes = kwargs["num_classes"]
self.max_len = None
self.mask = None
self.best_accuracy = 0.0
def data_forward(self, network, inputs):
if not isinstance(inputs, tuple):
raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0])))
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask
y = network(x)
return y
def get_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.
:param predict: prediction label vector, [batch_size, max_len, tag_size]
:param truth: ground truth label vector, [batch_size, max_len]
:return loss: a scalar
"""
batch_size, max_len = predict.size(0), predict.size(1)
assert truth.shape == (batch_size, max_len)
loss = self._model.loss(predict, truth, self.mask)
return loss
def best_eval_result(self, validator):
loss, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False
def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda)
def _create_validator(self, valid_args):
return SeqLabelTester(**valid_args)
@ -349,33 +332,10 @@ class ClassificationTrainer(BaseTrainer):
"""Trainer for text classification."""
def __init__(self, **train_args):
train_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
super(ClassificationTrainer, self).__init__(**train_args)
self.iterator = None
self.loss_func = None
self.optimizer = None
self.best_accuracy = 0
def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits
def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)
def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)
def best_eval_result(self, validator):
_, _, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False
def _create_validator(self, valid_args):
return ClassificationTester(**valid_args)
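A rough sketch of the dispatch in the new data_forward: batch_x produced by Batch carries a word_seq tensor (plus word_seq_origin_len for sequence labelling), and _task decides how the network is called. The module below is a dummy stand-in, not a fastNLP model:

import torch
import torch.nn as nn

class DummySeqLabelNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.out = nn.Linear(4, 3)

    def forward(self, word_seq, word_seq_origin_len):
        # origin lengths are accepted but ignored in this dummy network
        return self.out(self.embed(word_seq))  # [batch_size, max_len, tag_size]

task = "seq_label"
batch_x = {"word_seq": torch.LongTensor([[1, 2, 3, 0], [4, 5, 0, 0]]),
           "word_seq_origin_len": torch.LongTensor([3, 2])}
network = DummySeqLabelNet()
if task == "seq_label":
    y = network(batch_x["word_seq"], batch_x["word_seq_origin_len"])
elif task == "text_classify":
    y = network(batch_x["word_seq"])
print(y.shape)  # torch.Size([2, 4, 3])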

fastNLP/models/cnn_text_classification.py

@ -35,8 +35,12 @@ class CNNText(torch.nn.Module):
self.dropout = nn.Dropout(drop_prob)
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
def forward(self, word_seq):
"""
:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return x: torch.LongTensor, [batch_size, num_classes]
"""
x = self.embed(word_seq) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
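The shape comments above describe the flow [N, L] -> [N, L, C] -> [N, C] -> [N, num_classes]; a plain-PyTorch sketch of the same flow (the real module uses its own encoder components for embed, conv_pool and fc, so the layers here are only illustrative):

import torch
import torch.nn as nn

N, L, C, num_classes = 2, 7, 16, 4
word_seq = torch.randint(0, 100, (N, L))   # [N, L]
embed = nn.Embedding(100, C)
conv = nn.Conv1d(C, C, kernel_size=3, padding=1)
fc = nn.Linear(C, num_classes)

x = embed(word_seq)                        # [N, L] -> [N, L, C]
x = conv(x.transpose(1, 2))                # [N, C, L]
x = torch.max(x, dim=-1)[0]                # max-over-time pooling -> [N, C]
x = fc(x)                                  # [N, C] -> [N, num_classes]
print(x.shape)                             # torch.Size([2, 4])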

fastNLP/models/sequence_modeling.py

@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder
def seq_mask(seq_len, max_len):
"""Create a mask for the sequences.
:param seq_len: list or torch.LongTensor
:param max_len: int
:return mask: torch.LongTensor
"""
if isinstance(seq_len, list):
seq_len = torch.LongTensor(seq_len)
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)]
mask = torch.stack(mask, 1)
return mask
class SeqLabeling(BaseModel):
"""
PyTorch Network for sequence labeling
@ -20,13 +34,17 @@ class SeqLabeling(BaseModel):
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None
def forward(self, x):
def forward(self, word_seq, word_seq_origin_len):
"""
:param x: LongTensor, [batch_size, max_len]
:param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
:return y: [batch_size, max_len, tag_size]
"""
x = self.Embedding(x)
self.mask = self.make_mask(word_seq, word_seq_origin_len)
x = self.Embedding(word_seq)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
@ -34,27 +52,34 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, num_classes]
return x
def loss(self, x, y, mask):
def loss(self, x, y):
"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len]
:param mask: ByteTensor, [batch_size, max_len]
:return loss: a scalar Tensor
"""
x = x.float()
y = y.long()
total_loss = self.Crf(x, y, mask)
assert x.shape[:2] == y.shape
assert y.shape == self.mask.shape
total_loss = self.Crf(x, y, self.mask)
return torch.mean(total_loss)
def prediction(self, x, mask):
def make_mask(self, x, seq_len):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
mask = mask.to(x)
return mask
def prediction(self, x):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param mask: ByteTensor, [batch_size, max_len]
:return prediction: list of [decode path(list)]
"""
tag_seq = self.Crf.viterbi_decode(x, mask)
tag_seq = self.Crf.viterbi_decode(x, self.mask)
return tag_seq
@ -81,14 +106,17 @@ class AdvSeqLabel(SeqLabeling):
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
def forward(self, x):
def forward(self, word_seq, word_seq_origin_len):
"""
:param x: LongTensor, [batch_size, max_len]
:param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: list of int.
:return y: [batch_size, max_len, tag_size]
"""
batch_size = x.size(0)
max_len = x.size(1)
x = self.Embedding(x)
self.mask = self.make_mask(word_seq, word_seq_origin_len)
batch_size = word_seq.size(0)
max_len = word_seq.size(1)
x = self.Embedding(word_seq)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
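A worked example of the seq_mask helper added at the top of this file: positions up to each sequence's true length are 1 and the padded tail is 0.

from fastNLP.models.sequence_modeling import seq_mask

mask = seq_mask([4, 2], max_len=5)
print(mask.long())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 0, 0, 0]])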

test/core/test_action.py (deleted)

@ -1,17 +0,0 @@
import unittest
from fastNLP.core.action import Action, Batchifier, SequentialSampler
class TestAction(unittest.TestCase):
def test_case_1(self):
x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [1, 1, 1, 1, 2, 2, 2, 2]
data = []
for i in range(len(x)):
data.append([[x[i]], [y[i]]])
data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False)
action = Action()
for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None):
print(batch_x)

test/core/test_batch.py (new file, 62 lines)

@ -0,0 +1,62 @@
import unittest
import torch
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
raw_texts = ["i am a cat",
"this is a test of new batch",
"ha ha",
"I am a good boy .",
"This is the most beautiful girl ."
]
texts = [text.strip().split() for text in raw_texts]
labels = [0, 1, 0, 0, 1]
# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text:
if tokens not in vocab:
vocab[tokens] = len(vocab)
class TestCase1(unittest.TestCase):
def test(self):
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text, is_target=False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)
# use vocabulary to index data
data.index_field("text", vocab)
# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))
# use batch to iterate dataset
data_iterator = Batch(data, 2, SeqSampler(), False)
for batch_x, batch_y in data_iterator:
self.assertEqual(len(batch_x), 2)
self.assertTrue(isinstance(batch_x, dict))
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
self.assertTrue(isinstance(batch_y, dict))
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))
class TestCase2(unittest.TestCase):
def test(self):
data = DataSet()
for text in texts:
x = TextField(text, is_target=False)
ins = Instance(text=x)
data.append(ins)
data_set = create_dataset_from_lists(texts, vocab, has_target=False)
self.assertTrue(type(data) == type(data_set))

test/core/test_predictor.py (new file, 51 lines)

@ -0,0 +1,51 @@
import os
import unittest
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.models.sequence_modeling import SeqLabeling
class TestPredictor(unittest.TestCase):
def test_seq_label(self):
model_args = {
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 5
}
infer_data = [
['a', 'b', 'c', 'd', 'e'],
['a', '@', 'c', 'd', 'e'],
['a', 'b', '#', 'd', 'e'],
['a', 'b', 'c', '?', 'e'],
['a', 'b', 'c', 'd', '$'],
['!', 'b', 'c', 'd', 'e']
]
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
os.system("mkdir save")
save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
save_pickle(vocab, "./save/", "word2id.pkl")
model = SeqLabeling(model_args)
predictor = Predictor("./save/", task="seq_label")
results = predictor.predict(network=model, data=infer_data)
self.assertTrue(isinstance(results, list))
self.assertGreater(len(results), 0)
for res in results:
self.assertTrue(isinstance(res, list))
self.assertEqual(len(res), 5)
self.assertTrue(isinstance(res[0], str))
os.system("rm -rf save")
print("pickle path deleted")
class TestPredictor2(unittest.TestCase):
def test_text_classify(self):
# TODO
pass

View File

@ -1,24 +1,25 @@
import os
import unittest
from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import SeqLabelPreprocess
data = [
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
]
class TestSeqLabelPreprocess(unittest.TestCase):
def test_case_1(self):
data = [
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
]
class TestCase1(unittest.TestCase):
def test(self):
if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False):
for name in files:
@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase):
os.rmdir(os.path.join(root, name))
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
pickle_path="./save")
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
pickle_path="./save")
self.assertEqual(len(result), 2)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)
os.system("rm -rf save")
print("pickle path deleted")
class TestCase2(unittest.TestCase):
def test(self):
if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True)
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True)
cross_val=False)
self.assertEqual(len(result), 3)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)
self.assertEqual(type(result[2]), DataSet)
os.system("rm -rf save")
print("pickle path deleted")
class TestCase3(unittest.TestCase):
def test(self):
num_folds = 2
result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True, n_fold=num_folds)
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]), num_folds)
self.assertEqual(len(result[1]), num_folds)
for data_set in result[0] + result[1]:
self.assertEqual(type(data_set), DataSet)
os.system("rm -rf save")
print("pickle path deleted")

View File

@ -1,37 +1,55 @@
from fastNLP.core.preprocess import SeqLabelPreprocess
import os
import unittest
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.models.sequence_modeling import SeqLabeling
data_name = "pku_training.utf8"
pickle_path = "data_for_tests"
def foo():
loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
train_data = loader.load_pku()
class TestTester(unittest.TestCase):
def test_case_1(self):
model_args = {
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 5
}
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 2, "pickle_path": "./save/",
"use_cuda": False, "print_every_step": 1}
train_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})
train_data = [
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
]
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
# Preprocessor
p = SeqLabelPreprocess()
train_data = p.run(train_data)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes
data_set = DataSet()
for example in train_data:
text, label = example[0], example[1]
x = TextField(text, False)
y = TextField(label, is_target=True)
ins = Instance(word_seq=x, label_seq=y)
data_set.append(ins)
model = SeqLabeling(train_args)
data_set.index_field("word_seq", vocab)
data_set.index_field("label_seq", label_vocab)
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
"use_cuda": True}
validator = SeqLabelTester(**valid_args)
model = SeqLabeling(model_args)
print("start validation.")
validator.test(model, train_data)
print(validator.show_metrics())
tester = SeqLabelTester(**valid_args)
tester.test(network=model, dev_data=data_set)
# If this can run, everything is OK.
if __name__ == "__main__":
foo()
os.system("rm -rf save")
print("pickle path deleted")

View File

@ -1,33 +1,54 @@
import os
import torch.nn as nn
import unittest
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling
class TestTrainer(unittest.TestCase):
def test_case_1(self):
args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": True, "model_name": "default_model_name.pkl",
"loss": Loss(None),
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"vocab_size": 20,
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 3
"num_classes": 5
}
trainer = SeqLabelTrainer()
trainer = SeqLabelTrainer(**args)
train_data = [
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
]
dev_data = train_data
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
data_set = DataSet()
for example in train_data:
text, label = example[0], example[1]
x = TextField(text, False)
y = TextField(label, is_target=True)
ins = Instance(word_seq=x, label_seq=y)
data_set.append(ins)
data_set.index_field("word_seq", vocab)
data_set.index_field("label_seq", label_vocab)
model = SeqLabeling(args)
trainer.train(network=model, train_data=train_data, dev_data=dev_data)
trainer.train(network=model, train_data=data_set, dev_data=data_set)
# If this can run, everything is OK.
os.system("rm -rf save")
print("pickle path deleted")

View File

@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt",
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt",
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
help="data used for inference")
args = parser.parse_args()
@ -86,7 +86,7 @@ def train_and_test():
trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=trainer_args["validate"],
validate=False,
use_cuda=trainer_args["use_cuda"],
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
@ -121,7 +121,7 @@ def train_and_test():
# Tester
tester = SeqLabelTester(save_output=False,
save_loss=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
use_cuda=False,
@ -139,5 +139,5 @@ def train_and_test():
if __name__ == "__main__":
# train_and_test()
infer()
train_and_test()
# infer()

View File

@ -1,8 +0,0 @@
def test_charlm():
pass
if __name__ == "__main__":
test_charlm()

View File

@ -0,0 +1,85 @@
import os
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver
pickle_path = "./seq_label/"
model_name = "seq_label_model.pkl"
config_dir = "test/data_for_tests/config"
data_path = "test/data_for_tests/people.txt"
data_infer_path = "test/data_for_tests/people_infer.txt"
def test_training():
# Config Loader
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader("_").load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
# Data Loader
pos_loader = POSDatasetLoader(data_path)
train_data = pos_loader.load_lines()
# Preprocessor
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
model_args["vocab_size"] = p.vocab_size
model_args["num_classes"] = p.num_classes
trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=False,
use_cuda=False,
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
)
# Model
model = SeqLabeling(model_args)
# Start training
trainer.train(model, data_train, data_dev)
# Saver
saver = ModelSaver(os.path.join(pickle_path, model_name))
saver.save_pytorch(model)
del model, trainer, pos_loader
# Define the same model
model = SeqLabeling(model_args)
# Dump trained parameters into the model
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
# Load test configuration
tester_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
# Tester
tester = SeqLabelTester(save_output=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
use_cuda=False,
pickle_path=pickle_path,
model_name="seq_label_in_test.pkl",
print_every_step=1
)
# Start testing with validation data
tester.test(model, data_dev)
loss, accuracy = tester.metrics
assert 0 < accuracy < 1

View File

@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
args = parser.parse_args()
@ -115,4 +115,4 @@ def train():
if __name__ == "__main__":
train()
infer()
# infer()