Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 20:57:37 +08:00)
Commit 1d4e406e6f
@@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
from fastNLP.modules import aggregation
from fastNLP.modules import decoder

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.inference import ClassificationInfer
@@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
train_path = 'test/data_for_tests/text_classify.txt'  # training set file

# load dataset
ds_loader = ClassDatasetLoader("train", train_path)
ds_loader = ClassDataSetLoader("train", train_path)
data = ds_loader.load()

# pre-process dataset
@@ -3,7 +3,7 @@ from fastNLP.core.optimizer import Optimizer
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator
from fastNLP.modules import decoder
@@ -36,7 +36,7 @@ data_dir = 'save/'  # directory to save data and model
train_path = './data_for_tests/text_classify.txt'  # training set file

# load dataset
ds_loader = ClassDatasetLoader(train_path)
ds_loader = ClassDataSetLoader()
data = ds_loader.load()

# pre-process dataset
@@ -17,7 +17,7 @@ class Batch(object):
        :param dataset: a DataSet object
        :param batch_size: int, the size of the batch
        :param sampler: a Sampler object
        :param use_cuda: bool, whetjher to use GPU
        :param use_cuda: bool, whether to use GPU

        """
        self.dataset = dataset
@@ -37,15 +37,12 @@ class Batch(object):
        """

        :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
                batch_x also contains an item (str: list of int) about origin lengths,
                which means ("field_name_origin_len": origin lengths).
                E.g.
                ::
                    {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})

                batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
                All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
                The names of fields are defined in preprocessor's convert_to_dataset method.

        """
        if self.curidx >= len(self.idx_list):
@@ -54,10 +51,9 @@ class Batch(object):
        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        padding_length = {field_name: max(field_length[self.curidx: endidx])
                          for field_name, field_length in self.lengths.items()}
        origin_lengths = {field_name: field_length[self.curidx: endidx]
                          for field_name, field_length in self.lengths.items()}

        batch_x, batch_y = defaultdict(list), defaultdict(list)

        # transform index to tensor and do padding for sequences
        for idx in range(self.curidx, endidx):
            x, y = self.dataset.to_tensor(idx, padding_length)
            for name, tensor in x.items():
@@ -65,8 +61,7 @@ class Batch(object):
            for name, tensor in y.items():
                batch_y[name].append(tensor)

        batch_origin_length = {}
        # combine instances into a batch
        # combine instances to form a batch
        for batch in (batch_x, batch_y):
            for name, tensor_list in batch.items():
                if self.use_cuda:
@@ -74,14 +69,6 @@ class Batch(object):
                else:
                    batch[name] = torch.stack(tensor_list, dim=0)

        # add origin lengths in batch_x
        for name, tensor in batch_x.items():
            if self.use_cuda:
                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
            else:
                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
        batch_x.update(batch_origin_length)

        self.curidx += endidx
        self.curidx = endidx
        return batch_x, batch_y
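For illustration, a minimal sketch of how the reworked Batch iterator is typically driven; the dataset, sampler, and field names here are assumptions, not taken from this diff:

```python
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# `train_data` is assumed to be a preprocessed DataSet with a "word_seq" field
data_iterator = Batch(train_data, batch_size=8, sampler=SequentialSampler(), use_cuda=False)
for batch_x, batch_y in data_iterator:
    word_seq = batch_x["word_seq"]            # padded LongTensor [batch_size, max_len]
    seq_len = batch_x["word_seq_origin_len"]  # original lengths added by Batch
```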
@@ -1,7 +1,11 @@
import random
from collections import defaultdict
from copy import deepcopy

from fastNLP.core.field import TextField
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
@@ -65,7 +69,8 @@ class DataSet(list):
    """A DataSet object is a list of Instance objects.

    """
    def __init__(self, name="", instances=None):

    def __init__(self, name="", instances=None, loader=None):
        """

        :param name: str, the name of the dataset. (default: "")
@@ -76,6 +81,7 @@ class DataSet(list):
        self.name = name
        if instances is not None:
            self.extend(instances)
        self.dataset_loader = loader

    def index_all(self, vocab):
        for ins in self:
@@ -109,3 +115,180 @@ class DataSet(list):
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths

    def convert(self, data):
        """Convert lists of strings into Instances with Fields"""
        raise NotImplementedError

    def convert_with_vocabs(self, data, vocabs):
        """Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
        raise NotImplementedError

    def convert_for_infer(self, data, vocabs):
        """Convert lists of strings into Instances with Fields."""

    def load(self, data_path, vocabs=None, infer=False):
        """Load data from the given files.

        :param data_path: str, the path to the data
        :param infer: bool. If True, there is no label information in the data. Default: False.
        :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.

        """
        raw_data = self.dataset_loader.load(data_path)
        if infer is True:
            self.convert_for_infer(raw_data, vocabs)
        else:
            if vocabs is not None:
                self.convert_with_vocabs(raw_data, vocabs)
            else:
                self.convert(raw_data)

    def load_raw(self, raw_data, vocabs):
        """

        :param raw_data:
        :param vocabs:
        :return:
        """
        self.convert_for_infer(raw_data, vocabs)

    def split(self, ratio, shuffle=True):
        """Train/dev splitting

        :param ratio: float, between 0 and 1. The ratio of development set in origin data set.
        :param shuffle: bool, whether shuffle the data set before splitting. Default: True.
        :return train_set: a DataSet object, representing the training set
                dev_set: a DataSet object, representing the validation set

        """
        assert 0 < ratio < 1
        if shuffle:
            random.shuffle(self)
        split_idx = int(len(self) * ratio)
        dev_set = deepcopy(self)
        train_set = deepcopy(self)
        del train_set[:split_idx]
        del dev_set[split_idx:]
        return train_set, dev_set
class SeqLabelDataSet(DataSet):
    def __init__(self, instances=None, loader=POSDataSetLoader()):
        super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary()

    def convert(self, data):
        """Convert lists of strings into Instances with Fields.

        :param data: 3-level lists. Entries are strings.
        """
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label_seq)
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", self.word_vocab)
        self.index_field("truth", self.label_vocab)
        # no need to index "word_seq_origin_len"

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("truth", vocabs["label_vocab"])
        # no need to index "word_seq_origin_len"

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        # no need to index "word_seq_origin_len"


class TextClassifyDataSet(DataSet):
    def __init__(self, instances=None, loader=ClassDataSetLoader()):
        super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary(need_default=False)

    def convert(self, data):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label)
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", self.word_vocab)
        self.index_field("label", self.label_vocab)

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("label", vocabs["label_vocab"])

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])


def change_field_is_target(data_set, field_name, new_target):
    """Change the flag of is_target in a field.

    :param data_set: a DataSet object
    :param field_name: str, the name of the field
    :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.

    """
    for inst in data_set:
        inst.fields[field_name].is_target = new_target


if __name__ == "__main__":
    data_set = SeqLabelDataSet()
    data_set.load("../../test/data_for_tests/people.txt")
    a, b = data_set.split(0.3)
    print(type(data_set), type(a), type(b))
    print(len(data_set), len(a), len(b))
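For illustration, a hedged sketch of using the new task-specific DataSet classes at inference time; the raw input and the word vocabulary built during training are assumptions:

```python
from fastNLP.core.dataset import TextClassifyDataSet

# 2-D list of tokens for two hypothetical sentences
raw_input = [["this", "movie", "is", "great"], ["boring", "plot"]]

infer_data = TextClassifyDataSet()
# reuse the word vocabulary built during training; no label fields are attached
infer_data.load_raw(raw_input, {"word_vocab": word_vocab})
```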
@@ -59,6 +59,9 @@ class TextField(Field):


class LabelField(Field):
    """The Field representing a single label. Can be a string or integer.

    """
    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
@@ -73,13 +76,14 @@ class LabelField(Field):

    def index(self, vocab):
        if self._index is None:
            if isinstance(self.label, str):
                self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        if self._index is None:
            if isinstance(self.label, int):
                return torch.LongTensor([self.label])
                return torch.tensor(self.label)
            elif isinstance(self.label, str):
                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
            else:
@@ -46,8 +46,11 @@ class Instance(object):
        tensor_x = {}
        tensor_y = {}
        for name, field in self.fields.items():
            if field.is_target:
            if field.is_target is True:
                tensor_y[name] = field.to_tensor(padding_length[name])
            else:
            elif field.is_target is False:
                tensor_x[name] = field.to_tensor(padding_length[name])
            else:
                # is_target is None
                continue
        return tensor_x, tensor_y
@@ -33,10 +33,25 @@ class Loss(object):
        """Given a name of a loss function, return it from PyTorch.

        :param loss_name: str, the name of a loss function

            - cross_entropy: combines log softmax and nll loss in a single function.
            - nll: negative log likelihood

        :return loss: a PyTorch loss
        """

        class InnerCrossEntropy:
            """A simple wrapper to guarantee input shapes."""

            def __init__(self):
                self.f = torch.nn.CrossEntropyLoss()

            def __call__(self, predict, truth):
                truth = truth.view(-1, )
                return self.f(predict, truth)

        if loss_name == "cross_entropy":
            return torch.nn.CrossEntropyLoss()
            return InnerCrossEntropy()
        elif loss_name == 'nll':
            return torch.nn.NLLLoss()
        else:
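A short sketch of why InnerCrossEntropy flattens the target before calling CrossEntropyLoss; the shapes are illustrative only:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
predict = torch.randn(4, 3)                 # [batch_size, num_classes]
truth = torch.tensor([[0], [2], [1], [2]])  # [batch_size, 1], as a label batch may arrive

# CrossEntropyLoss expects class-index targets of shape [batch_size],
# so the target is flattened before the call
loss = criterion(predict, truth.view(-1))
```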
@@ -4,6 +4,56 @@ import numpy as np
import torch


class Evaluator(object):
    def __init__(self):
        pass

    def __call__(self, predict, truth):
        """

        :param predict: list of tensors, the network outputs from all batches.
        :param truth: list of dict, the ground truths from all batch_y.
        :return:
        """
        raise NotImplementedError


class ClassifyEvaluator(Evaluator):
    def __init__(self):
        super(ClassifyEvaluator, self).__init__()

    def __call__(self, predict, truth):
        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
        y_prob = torch.cat(y_prob, dim=0)
        y_pred = torch.argmax(y_prob, dim=-1)
        y_true = torch.cat(truth, dim=0)
        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
        return {"accuracy": acc}


class SeqLabelEvaluator(Evaluator):
    def __init__(self):
        super(SeqLabelEvaluator, self).__init__()

    def __call__(self, predict, truth):
        """

        :param predict: list of List, the network outputs from all batches.
        :param truth: list of dict, the ground truths from all batch_y.
        :return accuracy:
        """
        truth = [item["truth"] for item in truth]
        total_correct, total_count = 0., 0.
        for x, y in zip(predict, truth):
            mask = torch.Tensor(x).ge(1)
            correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float())
            correct -= torch.sum(torch.Tensor(x).le(0))
            total_correct += float(correct)
            total_count += float(torch.sum(mask))
        accuracy = total_correct / total_count
        return {"accuracy": float(accuracy)}


def _conver_numpy(x):
    """convert input data to numpy array
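For illustration, a minimal sketch of calling the new ClassifyEvaluator on two hypothetical batches of logits and gold labels:

```python
import torch
from fastNLP.core.metrics import ClassifyEvaluator

predict = [torch.tensor([[2.0, 0.1], [0.3, 1.5]]),  # logits of batch 1
           torch.tensor([[0.2, 0.9]])]              # logits of batch 2
truth = [torch.tensor([0, 1]), torch.tensor([0])]   # gold labels per batch

print(ClassifyEvaluator()(predict, truth))          # {'accuracy': 0.666...}
```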
@@ -16,43 +16,42 @@ class Predictor(object):
    Currently, Predictor does not support GPU.
    """

    def __init__(self, pickle_path, task):
    def __init__(self, pickle_path, post_processor):
        """

        :param pickle_path: str, the path to the pickle files.
        :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
        :param post_processor: a function or callable object, that takes list of batch outputs as input

        """
        self.batch_size = 1
        self.batch_output = []
        self.pickle_path = pickle_path
        self._task = task  # one of ("seq_label", "text_classify")
        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
        self._post_processor = post_processor
        self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")

    def predict(self, network, data):
        """Perform inference using the trained model.

        :param network: a PyTorch model (cpu)
        :param data: list of list of strings, [num_examples, seq_len]
        :param data: a DataSet object.
        :return: list of list of strings, [num_examples, tag_seq_length]
        """
        # transform strings into DataSet object
        data = self.prepare_input(data)
        # data = self.prepare_input(data)

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.batch_output.clear()
        batch_output = []

        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

        for batch_x, _ in data_iterator:
            with torch.no_grad():
                prediction = self.data_forward(network, batch_x)
            batch_output.append(prediction)

            self.batch_output.append(prediction)

        return self.prepare_output(self.batch_output)
        return self._post_processor(batch_output, self.label_vocab)

    def mode(self, network, test=True):
        if test:
@@ -62,13 +61,7 @@ class Predictor(object):

    def data_forward(self, network, x):
        """Forward through network."""
        if self._task == "seq_label":
            y = network(x["word_seq"], x["word_seq_origin_len"])
            y = network.prediction(y)
        elif self._task == "text_classify":
            y = network(x["word_seq"])
        else:
            raise NotImplementedError("Unknown task type {}.".format(self._task))
        y = network(**x)
        return y

    def prepare_input(self, data):
@@ -88,39 +81,32 @@ class Predictor(object):
        assert isinstance(data, list)
        return create_dataset_from_lists(data, self.word_vocab, has_target=False)

    def prepare_output(self, data):
        """Transform list of batch outputs into strings."""
        if self._task == "seq_label":
            return self._seq_label_prepare_output(data)
        elif self._task == "text_classify":
            return self._text_classify_prepare_output(data)
        else:
            raise NotImplementedError("Unknown task type {}".format(self._task))

    def _seq_label_prepare_output(self, batch_outputs):
        results = []
        for batch in batch_outputs:
            for example in np.array(batch):
                results.append([self.label_vocab.to_word(int(x)) for x in example])
        return results

    def _text_classify_prepare_output(self, batch_outputs):
        results = []
        for batch_out in batch_outputs:
            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
            results.extend([self.label_vocab.to_word(i) for i in idx])
        return results


class SeqLabelInfer(Predictor):
    def __init__(self, pickle_path):
        print(
            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
        super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
        super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)


class ClassificationInfer(Predictor):
    def __init__(self, pickle_path):
        print(
            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
        super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
        super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)


def seq_label_post_processor(batch_outputs, label_vocab):
    results = []
    for batch in batch_outputs:
        for example in np.array(batch):
            results.append([label_vocab.to_word(int(x)) for x in example])
    return results


def text_classify_post_processor(batch_outputs, label_vocab):
    results = []
    for batch_out in batch_outputs:
        idx = np.argmax(batch_out.detach().numpy(), axis=-1)
        results.extend([label_vocab.to_word(i) for i in idx])
    return results
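A hedged usage sketch of the refactored Predictor, which now takes a post-processor instead of a task string; the trained network and the inference DataSet are assumed to exist:

```python
from fastNLP.core.predictor import Predictor, seq_label_post_processor

infer = Predictor(pickle_path="./save/", post_processor=seq_label_post_processor)
tag_sequences = infer.predict(network, infer_data)  # network: trained model, infer_data: a DataSet
```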
@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
    :param pickle_path: str, the directory where the pickle file is to be saved
    :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
    """
    if not os.path.exists(pickle_path):
        os.mkdir(pickle_path)
        print("make dir {} before saving pickle file".format(pickle_path))
    with open(os.path.join(pickle_path, file_name), "wb") as f:
        _pickle.dump(obj, f)
    print("{} saved in {}".format(file_name, pickle_path))
@@ -66,14 +69,26 @@ class Preprocessor(object):
    Preprocessors will check if those files are already in the directory and will reuse them in future calls.
    """

    def __init__(self, label_is_seq=False):
    def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
        """

        :param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
                several special tokens for sequence processing.
        :param share_vocab: bool, whether word sequence and label sequence share the same vocabulary. Typically, this
                is only available when label_is_seq is True. Default: False.
        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
        """
        self.data_vocab = Vocabulary()
        self.label_vocab = Vocabulary(need_default=label_is_seq)
        if label_is_seq is True:
            if share_vocab is True:
                self.label_vocab = self.data_vocab
            else:
                self.label_vocab = Vocabulary()
        else:
            self.label_vocab = Vocabulary(need_default=False)

        self.character_vocab = Vocabulary(need_default=False)
        self.add_char_field = add_char_field

    @property
    def vocab_size(self):
@@ -83,6 +98,12 @@ class Preprocessor(object):
    def num_classes(self):
        return len(self.label_vocab)

    @property
    def char_vocab_size(self):
        if self.character_vocab is None:
            self.build_char_dict()
        return len(self.character_vocab)

    def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
        """Main pre-processing pipeline.

@@ -96,7 +117,6 @@ class Preprocessor(object):
            If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
            is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
        """

        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
@@ -176,6 +196,16 @@ class Preprocessor(object):
                self.label_vocab.update(label)
        return self.data_vocab, self.label_vocab

    def build_char_dict(self):
        char_collection = set()
        for word in self.data_vocab.word2idx:
            if len(word) == 0:
                continue
            for ch in word:
                if ch not in char_collection:
                    char_collection.add(ch)
        self.character_vocab.update(list(char_collection))

    def build_reverse_dict(self):
        self.data_vocab.build_reverse_vocab()
        self.label_vocab.build_reverse_vocab()
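For illustration, a sketch of constructing the Preprocessor with the new flags; the training data is assumed to come from a DataSetLoader, as in the examples above:

```python
from fastNLP.core.preprocess import Preprocessor

p = Preprocessor(label_is_seq=True, share_vocab=False, add_char_field=False)
data_train = p.run(train_dev_data, pickle_path="./save/")
print(p.vocab_size, p.num_classes)
```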
@@ -1,7 +1,7 @@
import numpy as np
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.metrics import Evaluator
from fastNLP.core.sampler import RandomSampler
from fastNLP.saver.logger import create_logger

@@ -22,28 +22,23 @@ class Tester(object):
        "kwargs" must have the same type as "default_args" on corresponding keys.
        Otherwise, error will raise.
        """
        default_args = {"save_output": True,  # collect outputs of validation set
                        "save_loss": True,  # collect losses in validation
                        "save_best_dev": False,  # save best model during validation
                        "batch_size": 8,
        default_args = {"batch_size": 8,
                        "use_cuda": False,
                        "pickle_path": "./save/",
                        "model_name": "dev_best_model.pkl",
                        "print_every_step": 1,
                        "evaluator": Evaluator()
                        }
        """
        "required_args" is the collection of arguments that users must pass to Trainer explicitly.
        This is used to warn users of essential settings in the training.
        Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
        """
        required_args = {"task"  # one of ("seq_label", "text_classify")
                         }
        required_args = {}

        for req_key in required_args:
            if req_key not in kwargs:
                logger.error("Tester lacks argument {}".format(req_key))
                raise ValueError("Tester lacks argument {}".format(req_key))
        self._task = kwargs["task"]

        for key in default_args:
            if key in kwargs:
@@ -59,17 +54,13 @@ class Tester(object):
                pass
        print(default_args)

        self.save_output = default_args["save_output"]
        self.save_best_dev = default_args["save_best_dev"]
        self.save_loss = default_args["save_loss"]
        self.batch_size = default_args["batch_size"]
        self.pickle_path = default_args["pickle_path"]
        self.use_cuda = default_args["use_cuda"]
        self.print_every_step = default_args["print_every_step"]
        self._evaluator = default_args["evaluator"]

        self._model = None
        self.eval_history = []  # evaluation results of all batches
        self.batch_output = []  # outputs of all batches

    def test(self, network, dev_data):
        if torch.cuda.is_available() and self.use_cuda:
@@ -80,26 +71,18 @@ class Tester(object):
        # turn on the testing mode; clean up the history
        self.mode(network, is_test=True)
        self.eval_history.clear()
        self.batch_output.clear()
        output_list = []
        truth_list = []

        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
        step = 0

        for batch_x, batch_y in data_iterator:
            with torch.no_grad():
                prediction = self.data_forward(network, batch_x)
            eval_results = self.evaluate(prediction, batch_y)

            if self.save_output:
                self.batch_output.append(prediction)
            if self.save_loss:
                self.eval_history.append(eval_results)

            print_output = "[test step {}] {}".format(step, eval_results)
            logger.info(print_output)
            if self.print_every_step > 0 and step % self.print_every_step == 0:
                print(self.make_eval_output(prediction, eval_results))
            step += 1
            output_list.append(prediction)
            truth_list.append(batch_y)
        eval_results = self.evaluate(output_list, truth_list)
        print("[tester] {}".format(self.print_eval_results(eval_results)))
    def mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -121,104 +104,30 @@ class Tester(object):
    def evaluate(self, predict, truth):
        """Compute evaluation metrics.

        :param predict: Tensor
        :param truth: Tensor
        :param predict: list of Tensor
        :param truth: list of dict
        :return eval_results: can be anything. It will be stored in self.eval_history
        """
        if "label_seq" in truth:
            truth = truth["label_seq"]
        elif "label" in truth:
            truth = truth["label"]
        else:
            raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
        return self._evaluator(predict, truth)

        if self._task == "seq_label":
            return self._seq_label_evaluate(predict, truth)
        elif self._task == "text_classify":
            return self._text_classify_evaluate(predict, truth)
        else:
            raise NotImplementedError("Unknown task type {}.".format(self._task))
    def print_eval_results(self, results):
        """Override this method to support more print formats.

    def _seq_label_evaluate(self, predict, truth):
        batch_size, max_len = predict.size(0), predict.size(1)
        loss = self._model.loss(predict, truth) / batch_size
        prediction = self._model.prediction(predict)
        # pad prediction to equal length
        for pred in prediction:
            if len(pred) < max_len:
                pred += [0] * (max_len - len(pred))
        results = torch.Tensor(prediction).view(-1, )
        :param results: dict, (str: float) is (metrics name: value)

        # make sure "results" is in the same device as "truth"
        results = results.to(truth)
        accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
        return [float(loss), float(accuracy)]

    def _text_classify_evaluate(self, y_logit, y_true):
        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
        return [y_prob, y_true]

    @property
    def metrics(self):
        """Compute and return metrics.
        Use self.eval_history to compute metrics over the whole dev set.
        Please refer to metrics.py for common metric functions.

        :return : variable number of outputs
        """
        if self._task == "seq_label":
            return self._seq_label_metrics
        elif self._task == "text_classify":
            return self._text_classify_metrics
        else:
            raise NotImplementedError("Unknown task type {}.".format(self._task))

    @property
    def _seq_label_metrics(self):
        batch_loss = np.mean([x[0] for x in self.eval_history])
        batch_accuracy = np.mean([x[1] for x in self.eval_history])
        return batch_loss, batch_accuracy

    @property
    def _text_classify_metrics(self):
        y_prob, y_true = zip(*self.eval_history)
        y_prob = torch.cat(y_prob, dim=0)
        y_pred = torch.argmax(y_prob, dim=-1)
        y_true = torch.cat(y_true, dim=0)
        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

    def show_metrics(self):
        """Customize evaluation outputs in Trainer.
        Called by Trainer to print evaluation results on dev set during training.
        Use self.metrics to fetch available metrics.

        :return print_str: str
        """
        loss, accuracy = self.metrics
        return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

    def make_eval_output(self, predictions, eval_results):
        """Customize Tester outputs.

        :param predictions: Tensor
        :param eval_results: Tensor
        :return: str, to be printed.
        """
        return self.show_metrics()
        return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])


class SeqLabelTester(Tester):
    def __init__(self, **test_args):
        test_args.update({"task": "seq_label"})
        print(
            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
        super(SeqLabelTester, self).__init__(**test_args)


class ClassificationTester(Tester):
    def __init__(self, **test_args):
        test_args.update({"task": "text_classify"})
        print(
            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
        super(ClassificationTester, self).__init__(**test_args)
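A hedged sketch of configuring the slimmed-down Tester, whose metric is now injected through the evaluator argument; the trained network and dev DataSet are assumed:

```python
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.tester import Tester

tester = Tester(batch_size=8, use_cuda=False, pickle_path="./save/",
                evaluator=SeqLabelEvaluator())
tester.test(network, dev_data)  # prints "[tester] accuracy=..."
```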
@@ -8,6 +8,7 @@ from tensorboardX import SummaryWriter

from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
@@ -43,21 +44,20 @@ class Trainer(object):
        default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                        "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
                        "loss": Loss(None),  # used to pass type check
                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                        "evaluator": Evaluator()
                        }
        """
        "required_args" is the collection of arguments that users must pass to Trainer explicitly.
        This is used to warn users of essential settings in the training.
        Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
        """
        required_args = {"task"  # one of ("seq_label", "text_classify")
                         }
        required_args = {}

        for req_key in required_args:
            if req_key not in kwargs:
                logger.error("Trainer lacks argument {}".format(req_key))
                raise ValueError("Trainer lacks argument {}".format(req_key))
        self._task = kwargs["task"]

        for key in default_args:
            if key in kwargs:
@@ -86,6 +86,7 @@ class Trainer(object):
        self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
        self._optimizer = None
        self._optimizer_proto = default_args["optimizer"]
        self._evaluator = default_args["evaluator"]
        self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
        self._graph_summaried = False
        self._best_accuracy = 0.0
@@ -106,9 +107,8 @@ class Trainer(object):

        # define Tester over dev data
        if self.validate:
            default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                                  "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                  "use_cuda": self.use_cuda, "print_every_step": 0}
            default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                  "use_cuda": self.use_cuda, "evaluator": self._evaluator}
            validator = self._create_validator(default_valid_args)
            logger.info("validator defined as {}".format(str(validator)))

@@ -229,18 +229,9 @@ class Trainer(object):
        self._optimizer.step()

    def data_forward(self, network, x):
        if self._task == "seq_label":
            y = network(x["word_seq"], x["word_seq_origin_len"])
        elif self._task == "text_classify":
            y = network(x["word_seq"])
        else:
            raise NotImplementedError("Unknown task type {}.".format(self._task))

        y = network(**x)
        if not self._graph_summaried:
            if self._task == "seq_label":
                self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
            elif self._task == "text_classify":
                self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
            # self._summary_writer.add_graph(network, x, verbose=False)
            self._graph_summaried = True
        return y

@@ -261,13 +252,9 @@ class Trainer(object):
        :param truth: ground truth label vector
        :return: a scalar
        """
        if "label_seq" in truth:
            truth = truth["label_seq"]
        elif "label" in truth:
            truth = truth["label"]
            truth = truth.view((-1,))
        else:
            raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
        if len(truth) > 1:
            raise NotImplementedError("Not ready to handle multi-labels.")
        truth = list(truth.values())[0] if len(truth) > 0 else None
        return self._loss_func(predict, truth)

    def define_loss(self):
@@ -278,8 +265,8 @@ class Trainer(object):
        These two losses cannot be defined at the same time.
        Trainer does not handle loss definition or choose default losses.
        """
        if hasattr(self._model, "loss") and self._loss_func is not None:
            raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
        # if hasattr(self._model, "loss") and self._loss_func is not None:
        #     raise ValueError("Both the model and Trainer define loss. Please take out your loss.")

        if hasattr(self._model, "loss"):
            self._loss_func = self._model.loss
@@ -322,9 +309,8 @@ class SeqLabelTrainer(Trainer):

    """
    def __init__(self, **kwargs):
        kwargs.update({"task": "seq_label"})
        print(
            "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
            "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
        super(SeqLabelTrainer, self).__init__(**kwargs)

    def _create_validator(self, valid_args):
@@ -335,9 +321,8 @@ class ClassificationTrainer(Trainer):
    """Trainer for text classification."""

    def __init__(self, **train_args):
        train_args.update({"task": "text_classify"})
        print(
            "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
            "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
        super(ClassificationTrainer, self).__init__(**train_args)

    def _create_validator(self, valid_args):
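A hedged sketch of a Trainer configured with the new evaluator default; the train() call signature is an assumption, since it is not shown in this diff:

```python
from fastNLP.core.metrics import ClassifyEvaluator
from fastNLP.core.trainer import Trainer

trainer = Trainer(epochs=5, batch_size=32, validate=True, use_cuda=False,
                  pickle_path="./save/", evaluator=ClassifyEvaluator())
trainer.train(network, train_data, dev_data)  # argument names assumed
```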
@@ -10,6 +10,7 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
                         DEFAULT_RESERVED_LABEL[2]: 4}


def isiterable(p_object):
    try:
        it = iter(p_object)
@@ -17,6 +18,7 @@ def isiterable(p_object):
        return False
    return True


class Vocabulary(object):
    """Use for word and index one to one mapping

@@ -28,9 +30,11 @@ class Vocabulary(object):
        vocab["word"]
        vocab.to_word(5)
    """

    def __init__(self, need_default=True):
        """
        :param bool need_default: set if the Vocabulary has default labels reserved.
        :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.

        """
        if need_default:
            self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
@@ -50,7 +54,7 @@ class Vocabulary(object):
    def update(self, word):
        """add word or list of words into Vocabulary

        :param word: a list of str or str
        :param word: a list of string or a single string
        """
        if not isinstance(word, str) and isiterable(word):
            # it's a nested list
@@ -63,7 +67,6 @@ class Vocabulary(object):
        if self.idx2word is not None:
            self.idx2word = None

    def __getitem__(self, w):
        """To support usage like::

@@ -119,6 +122,3 @@ class Vocabulary(object):
        """
        self.__dict__.update(state)
        self.idx2word = None
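For illustration, basic Vocabulary usage matching the updated docstrings above:

```python
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(need_default=True)  # reserve the default padding/unknown entries
vocab.update(["the", "cat", "sat"])    # accepts a single string or a (nested) list
idx = vocab["cat"]                     # word -> index
word = vocab.to_word(idx)              # index -> word
```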
@@ -4,6 +4,8 @@ from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet


"""
mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -76,6 +78,8 @@ class FastNLP(object):
        self.model_dir = model_dir
        self.model = None
        self.infer_type = None  # "seq_label"/"text_class"
        self.word_vocab = None
        self.label_vocab = None

    def load(self, model_name, config_file="config", section_name="model"):
        """
@@ -100,10 +104,10 @@ class FastNLP(object):
        print("Restore model hyper-parameters {}".format(str(model_args.data)))

        # fetch dictionary size and number of labels from pickle files
        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
        model_args["vocab_size"] = len(word_vocab)
        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
        model_args["num_classes"] = len(label_vocab)
        self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
        model_args["vocab_size"] = len(self.word_vocab)
        self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
        model_args["num_classes"] = len(self.label_vocab)

        # Construct the model
        model = model_class(model_args)
@@ -130,8 +134,11 @@ class FastNLP(object):
        # tokenize: list of string ---> 2-D list of string
        infer_input = self.tokenize(raw_input, language="zh")

        # 2-D list of string ---> 2-D list of tags
        results = infer.predict(self.model, infer_input)
        # create DataSet: 2-D list of strings ----> DataSet
        infer_data = self._create_data_set(infer_input)

        # DataSet ---> 2-D list of tags
        results = infer.predict(self.model, infer_data)

        # 2-D list of tags ---> list of final answers
        outputs = self._make_output(results, infer_input)
@@ -154,6 +161,11 @@ class FastNLP(object):
        return module

    def _create_inference(self, model_dir):
        """Specify which task to perform.

        :param model_dir:
        :return:
        """
        if self.infer_type == "seq_label":
            return SeqLabelInfer(model_dir)
        elif self.infer_type == "text_class":
@@ -161,6 +173,24 @@ class FastNLP(object):
        else:
            raise ValueError("fail to create inference instance")

    def _create_data_set(self, infer_input):
        """Create a DataSet object given the raw inputs.

        :param infer_input: 2-D lists of strings
        :return data_set: a DataSet object
        """
        if self.infer_type == "seq_label":
            data_set = SeqLabelDataSet()
            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
            return data_set
        elif self.infer_type == "text_class":
            data_set = TextClassifyDataSet()
            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
            return data_set
        else:
            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))


    def _load(self, model_dir, model_name):
        # To do
        return 0
@@ -1,27 +1,22 @@
class BaseLoader(object):
    """docstring for BaseLoader"""

    def __init__(self, data_path):
    def __init__(self):
        super(BaseLoader, self).__init__()
        self.data_path = data_path

    def load(self):
        """
        :return: string
        """
        with open(self.data_path, "r", encoding="utf-8") as f:
            text = f.read()
        return text

    def load_lines(self):
        with open(self.data_path, "r", encoding="utf=8") as f:
    def load_lines(self, data_path):
        with open(data_path, "r", encoding="utf=8") as f:
            text = f.readlines()
        return [line.strip() for line in text]

    def load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
        return [[word for word in sent.strip()] for sent in text]


class ToyLoader0(BaseLoader):
    """
        For charLM
        For CharLM
    """

    def __init__(self, data_path):
@@ -8,9 +8,9 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
    """loader for configuration files"""

    def __int__(self, data_name, data_path):
        super(ConfigLoader, self).__init__(data_path)
        self.config = self.parse(super(ConfigLoader, self).load())
    def __int__(self, data_path):
        super(ConfigLoader, self).__init__()
        self.config = self.parse(super(ConfigLoader, self).load(data_path))

    @staticmethod
    def parse(string):
@@ -3,14 +3,17 @@ import os
from fastNLP.loader.base_loader import BaseLoader


class DatasetLoader(BaseLoader):
class DataSetLoader(BaseLoader):
    """"loader for data sets"""

    def __init__(self, data_path):
        super(DatasetLoader, self).__init__(data_path)
    def __init__(self):
        super(DataSetLoader, self).__init__()

    def load(self, path):
        raise NotImplementedError


class POSDatasetLoader(DatasetLoader):
class POSDataSetLoader(DataSetLoader):
    """Dataset Loader for POS Tag datasets.

    In these datasets, each line are divided by '\t'
@@ -31,16 +34,10 @@ class POSDatasetLoader(DatasetLoader):
    to label5.
    """

    def __init__(self, data_path):
        super(POSDatasetLoader, self).__init__(data_path)
    def __init__(self):
        super(POSDataSetLoader, self).__init__()

    def load(self):
        assert os.path.exists(self.data_path)
        with open(self.data_path, "r", encoding="utf-8") as f:
            line = f.read()
        return line

    def load_lines(self):
    def load(self, data_path):
        """
        :return data: three-level list
            [
@@ -49,7 +46,7 @@ class POSDatasetLoader(DatasetLoader):
                ...
            ]
        """
        with open(self.data_path, "r", encoding="utf-8") as f:
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return self.parse(lines)

@@ -79,15 +76,15 @@ class POSDatasetLoader(DatasetLoader):
        return data


class TokenizeDatasetLoader(DatasetLoader):
class TokenizeDataSetLoader(DataSetLoader):
    """
    Data set loader for tokenization data sets
    """

    def __init__(self, data_path):
        super(TokenizeDatasetLoader, self).__init__(data_path)
    def __init__(self):
        super(TokenizeDataSetLoader, self).__init__()

    def load_pku(self, max_seq_len=32):
    def load(self, data_path, max_seq_len=32):
        """
        load pku dataset for Chinese word segmentation
        CWS (Chinese Word Segmentation) pku training dataset format:
@@ -104,7 +101,7 @@ class TokenizeDatasetLoader(DatasetLoader):
        :return: three-level lists
        """
        assert isinstance(max_seq_len, int) and max_seq_len > 0
        with open(self.data_path, "r", encoding="utf-8") as f:
        with open(data_path, "r", encoding="utf-8") as f:
            sentences = f.readlines()
        data = []
        for sent in sentences:
@@ -135,15 +132,15 @@ class TokenizeDatasetLoader(DatasetLoader):
        return data


class ClassDatasetLoader(DatasetLoader):
class ClassDataSetLoader(DataSetLoader):
    """Loader for classification data sets"""

    def __init__(self, data_path):
        super(ClassDatasetLoader, self).__init__(data_path)
    def __init__(self):
        super(ClassDataSetLoader, self).__init__()

    def load(self):
        assert os.path.exists(self.data_path)
        with open(self.data_path, "r", encoding="utf-8") as f:
    def load(self, data_path):
        assert os.path.exists(data_path)
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return self.parse(lines)
@@ -169,21 +166,21 @@ class ClassDatasetLoader(DatasetLoader):
        return dataset


class ConllLoader(DatasetLoader):
class ConllLoader(DataSetLoader):
    """loader for conll format files"""

    def __int__(self, data_path):
        """
        :param str data_path: the path to the conll data set
        """
        super(ConllLoader, self).__init__(data_path)
        self.data_set = self.parse(self.load())
        super(ConllLoader, self).__init__()
        self.data_set = self.parse(self.load(data_path))

    def load(self):
    def load(self, data_path):
        """
        :return: list lines: all lines in a conll file
        """
        with open(self.data_path, "r", encoding="utf-8") as f:
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return lines

@@ -207,28 +204,48 @@ class ConllLoader(DatasetLoader):
        return sentences


class LMDatasetLoader(DatasetLoader):
    def __init__(self, data_path):
        super(LMDatasetLoader, self).__init__(data_path)
class LMDataSetLoader(DataSetLoader):
    """Language Model Dataset Loader

    def load(self):
        if not os.path.exists(self.data_path):
            raise FileNotFoundError("file {} not found.".format(self.data_path))
        with open(self.data_path, "r", encoding="utf=8") as f:
    This loader produces data for language model training in a supervised way.
    That means it has X and Y.

    """

    def __init__(self):
        super(LMDataSetLoader, self).__init__()

    def load(self, data_path):
        if not os.path.exists(data_path):
            raise FileNotFoundError("file {} not found.".format(data_path))
        with open(data_path, "r", encoding="utf=8") as f:
            text = " ".join(f.readlines())
        return text.strip().split()
        tokens = text.strip().split()
        return self.sentence_cut(tokens)

    def sentence_cut(self, tokens, sentence_length=15):
        start_idx = 0
        data_set = []
        for idx in range(len(tokens) // sentence_length):
            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
            if start_idx * idx + sentence_length + 1 >= len(tokens):
                # ad hoc
                y.extend(["<unk>"])
            data_set.append([x, y])
        return data_set


class PeopleDailyCorpusLoader(DatasetLoader):
class PeopleDailyCorpusLoader(DataSetLoader):
    """
    People Daily Corpus: Chinese word segmentation, POS tag, NER
    """

    def __init__(self, data_path):
        super(PeopleDailyCorpusLoader, self).__init__(data_path)
    def __init__(self):
        super(PeopleDailyCorpusLoader, self).__init__()

    def load(self):
        with open(self.data_path, "r", encoding="utf-8") as f:
    def load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            sents = f.readlines()

        pos_tag_examples = []
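A minimal sketch of the new loader calling convention, where the data path moves from the constructor into load(); the file path reuses the test file mentioned earlier in this diff:

```python
from fastNLP.loader.dataset_loader import ClassDataSetLoader

loader = ClassDataSetLoader()  # no path in the constructor any more
data = loader.load("test/data_for_tests/text_classify.txt")
```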
@@ -1,215 +1,8 @@
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from fastNLP.models.base_model import BaseModel

USE_GPU = True

"""
To be deprecated.
"""


class CharLM(BaseModel):
    """
    Controller of the Character-level Neural Language Model
    """
    def __init__(self, lstm_batch_size, lstm_seq_len):
        super(CharLM, self).__init__()
        """
        Settings: should come from config loader or pre-processing
        """
        self.word_embed_dim = 300
        self.char_embedding_dim = 15
        self.cnn_batch_size = lstm_batch_size * lstm_seq_len
        self.lstm_seq_len = lstm_seq_len
        self.lstm_batch_size = lstm_batch_size
        self.num_epoch = 10
        self.old_PPL = 100000
        self.best_PPL = 100000

        """
        These parameters are set by pre-processing.
        """
        self.max_word_len = None
        self.num_char = None
        self.vocab_size = None
        self.preprocess("./data_for_tests/charlm.txt")

        self.data = None  # named tuple to store all data set
        self.data_ready = False
        self.criterion = nn.CrossEntropyLoss()
        self._loss = None
        self.use_gpu = USE_GPU

        # word_emb_dim == hidden_size / num of hidden units
        self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)),
                       to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)))

        self.model = charLM(self.char_embedding_dim,
                            self.word_embed_dim,
                            self.vocab_size,
                            self.num_char,
                            use_gpu=self.use_gpu)
        for param in self.model.parameters():
            nn.init.uniform(param.data, -0.05, 0.05)

        self.learning_rate = 0.1
        self.optimizer = None

    def prepare_input(self, raw_text):
        """
        :param raw_text: raw input text consisting of words
        :return: torch.Tensor, torch.Tensor
                 feature matrix, label vector
        This function is only called once in Trainer.train, but may called multiple times in Tester.test
        So Tester will save test input for frequent calls.
        """
        if os.path.exists("cache/prep.pt") is False:
            self.preprocess("./data_for_tests/charlm.txt")  # To do: This is not good. Need to fix..
        objects = torch.load("cache/prep.pt")
        word_dict = objects["word_dict"]
        char_dict = objects["char_dict"]
        max_word_len = self.max_word_len
        print("word/char dictionary built. Start making inputs.")

        words = raw_text
        input_vec = np.array(text2vec(words, char_dict, max_word_len))
        # Labels are next-word index in word_dict with the same length as inputs
        input_label = np.array([word_dict[w] for w in words[1:]] + [word_dict[words[-1]]])
        feature_input = torch.from_numpy(input_vec)
        label_input = torch.from_numpy(input_label)
        return feature_input, label_input

    def mode(self, test=False):
        if test:
            self.model.eval()
        else:
            self.model.train()

    def data_forward(self, x):
        """
        :param x: Tensor of size [lstm_batch_size, lstm_seq_len, max_word_len+2]
        :return: Tensor of size [num_words, ?]
        """
        # additional processing of inputs after batching
        num_seq = x.size()[0] // self.lstm_seq_len
        x = x[:num_seq * self.lstm_seq_len, :]
        x = x.view(-1, self.lstm_seq_len, self.max_word_len + 2)

        # detach hidden state of LSTM from last batch
        hidden = [state.detach() for state in self.hidden]
        output, self.hidden = self.model(to_var(x), hidden)
        return output
    def grad_backward(self):
        self.model.zero_grad()
        self._loss.backward()
        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
        self.optimizer.step()

    def get_loss(self, predict, truth):
        self._loss = self.criterion(predict, to_var(truth))
        return self._loss.data  # no PyTorch data structure exposed outside

    def define_optimizer(self):
        # redefine optimizer for every new epoch
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.85)

    def save(self):
        print("network saved")
        # torch.save(self.models, "cache/models.pkl")
    def preprocess(self, all_text_files):
        word_dict, char_dict = create_word_char_dict(all_text_files)
        num_char = len(char_dict)
        self.vocab_size = len(word_dict)
        char_dict["BOW"] = num_char + 1
        char_dict["EOW"] = num_char + 2
        char_dict["PAD"] = 0
        self.num_char = num_char + 3
        # char_dict is a dict of (string: int), with indices counting from 0 to 47
        reverse_word_dict = {value: key for key, value in word_dict.items()}
        self.max_word_len = max([len(word) for word in word_dict])
        objects = {
            "word_dict": word_dict,
            "char_dict": char_dict,
            "reverse_word_dict": reverse_word_dict,
        }
        if not os.path.exists("cache"):
            os.mkdir("cache")
        torch.save(objects, "cache/prep.pt")
        print("Preprocess done.")
"""
|
||||
Global Functions
|
||||
"""
|
||||
|
||||
|
||||
def batch_generator(x, batch_size):
|
||||
# x: [num_words, in_channel, height, width]
|
||||
# partitions x into batches
|
||||
num_step = x.size()[0] // batch_size
|
||||
for t in range(num_step):
|
||||
yield x[t * batch_size:(t + 1) * batch_size]
|
||||
|
||||
|
||||
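batch_generator slices the first dimension into fixed-size batches and silently drops the remainder; a small sketch of that behaviour with made-up sizes:

# Illustration only: batch_generator yields floor(N / batch_size) batches and drops the tail.
import torch
x = torch.arange(10).view(10, 1, 1, 1)   # N = 10 "words"
batches = list(batch_generator(x, batch_size=4))
# len(batches) == 2; the last two rows (indices 8 and 9) are never yielded.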
def text2vec(words, char_dict, max_word_len):
    """ Return list of list of int """
    word_vec = []
    for word in words:
        vec = [char_dict[ch] for ch in word]
        if len(vec) < max_word_len:
            vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
        vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
        word_vec.append(vec)
    return word_vec
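Each word therefore becomes a vector of length max_word_len + 2: a BOW marker, the character ids padded with PAD, and an EOW marker. A minimal sketch with a toy, hypothetical char_dict:

# Toy illustration of text2vec; this char_dict is hypothetical, not the one built by preprocess().
char_dict = {"c": 1, "a": 2, "t": 3, "PAD": 0, "BOW": 4, "EOW": 5}
print(text2vec(["cat", "at"], char_dict, max_word_len=4))
# [[4, 1, 2, 3, 0, 5], [4, 2, 3, 0, 0, 5]]  -> each vector has max_word_len + 2 entries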
def read_data(file_name):
    with open(file_name, 'r') as f:
        corpus = f.read().lower()
    import re
    corpus = re.sub(r"<unk>", "unk", corpus)
    return corpus.split()


def get_char_dict(vocabulary):
    char_dict = dict()
    count = 1
    for word in vocabulary:
        for ch in word:
            if ch not in char_dict:
                char_dict[ch] = count
                count += 1
    return char_dict


def create_word_char_dict(*file_name):
    text = []
    for file in file_name:
        text += read_data(file)
    word_dict = {word: ix for ix, word in enumerate(set(text))}
    char_dict = get_char_dict(word_dict)
    return word_dict, char_dict


def to_var(x):
    if torch.cuda.is_available() and USE_GPU:
        x = x.cuda()
    return Variable(x)
"""
|
||||
Neural Network
|
||||
"""
|
||||
from fastNLP.modules.encoder.lstm import LSTM
|
||||
|
||||
|
||||
class Highway(nn.Module):
|
||||
@ -225,9 +18,8 @@ class Highway(nn.Module):
|
||||
return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x)
|
||||
|
||||
|
||||
class charLM(nn.Module):
|
||||
"""Character-level Neural Language Model
|
||||
CNN + highway network + LSTM
|
||||
class CharLM(nn.Module):
|
||||
"""CNN + highway network + LSTM
|
||||
# Input:
|
||||
4D tensor with shape [batch_size, in_channel, height, width]
|
||||
# Output:
|
||||
@ -241,8 +33,8 @@ class charLM(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, char_emb_dim, word_emb_dim,
|
||||
vocab_size, num_char, use_gpu):
|
||||
super(charLM, self).__init__()
|
||||
vocab_size, num_char):
|
||||
super(CharLM, self).__init__()
|
||||
self.char_emb_dim = char_emb_dim
|
||||
self.word_emb_dim = word_emb_dim
|
||||
self.vocab_size = vocab_size
|
||||
@ -254,8 +46,7 @@ class charLM(nn.Module):
|
||||
self.convolutions = []
|
||||
|
||||
# list of tuples: (the number of filter, width)
|
||||
# self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
|
||||
self.filter_num_width = [(25, 1), (50, 2), (75, 3)]
|
||||
self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
|
||||
|
||||
for out_channel, filter_width in self.filter_num_width:
|
||||
self.convolutions.append(
|
||||
@ -278,29 +69,13 @@ class charLM(nn.Module):
|
||||
# LSTM
|
||||
self.lstm_num_layers = 2
|
||||
|
||||
self.lstm = nn.LSTM(input_size=self.highway_input_dim,
|
||||
hidden_size=self.word_emb_dim,
|
||||
num_layers=self.lstm_num_layers,
|
||||
bias=True,
|
||||
dropout=0.5,
|
||||
batch_first=True)
|
||||
|
||||
self.lstm = LSTM(self.highway_input_dim, hidden_size=self.word_emb_dim, num_layers=self.lstm_num_layers,
|
||||
dropout=0.5)
|
||||
# output layer
|
||||
self.dropout = nn.Dropout(p=0.5)
|
||||
self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)
|
||||
|
||||
if use_gpu is True:
|
||||
for x in range(len(self.convolutions)):
|
||||
self.convolutions[x] = self.convolutions[x].cuda()
|
||||
self.highway1 = self.highway1.cuda()
|
||||
self.highway2 = self.highway2.cuda()
|
||||
self.lstm = self.lstm.cuda()
|
||||
self.dropout = self.dropout.cuda()
|
||||
self.char_embed = self.char_embed.cuda()
|
||||
self.linear = self.linear.cuda()
|
||||
self.batch_norm = self.batch_norm.cuda()
|
||||
|
||||
def forward(self, x, hidden):
|
||||
def forward(self, x):
|
||||
# Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2]
|
||||
# Return: Variable of Tensor with shape [num_words, len(word_dict)]
|
||||
lstm_batch_size = x.size()[0]
|
||||
@ -313,7 +88,7 @@ class charLM(nn.Module):
|
||||
# [num_seq*seq_len, max_word_len+2, char_emb_dim]
|
||||
|
||||
x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3)
|
||||
# [num_seq*seq_len, 1, char_emb_dim, max_word_len+2]
|
||||
# [num_seq*seq_len, 1, max_word_len+2, char_emb_dim]
|
||||
|
||||
x = self.conv_layers(x)
|
||||
# [num_seq*seq_len, total_num_filters]
|
||||
@ -328,7 +103,7 @@ class charLM(nn.Module):
|
||||
x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
|
||||
# [num_seq, seq_len, total_num_filters]
|
||||
|
||||
x, hidden = self.lstm(x, hidden)
|
||||
x, hidden = self.lstm(x)
|
||||
# [seq_len, num_seq, hidden_size]
|
||||
|
||||
x = self.dropout(x)
|
||||
@ -339,7 +114,7 @@ class charLM(nn.Module):
|
||||
|
||||
x = self.linear(x)
|
||||
# [num_seq*seq_len, vocab_size]
|
||||
return x, hidden
|
||||
return x
|
||||
|
||||
def conv_layers(self, x):
|
||||
chosen_list = list()
|
||||
|
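After this refactor the model no longer threads an explicit hidden state through forward. A minimal sketch of how the new CharLM might be driven, with illustrative (made-up) dimensions and batch shapes:

# Sketch only: the sizes below are hypothetical and follow the shape comments in forward().
import torch
from fastNLP.models.char_language_model import CharLM

model = CharLM(char_emb_dim=50, word_emb_dim=50, vocab_size=1000, num_char=60)
x = torch.zeros(2, 10, 18).long()   # [num_seq, seq_len, max_word_len + 2] of character ids
logits = model(x)                   # [num_seq * seq_len, vocab_size]; no hidden state is passed in or returned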
@ -31,16 +31,18 @@ class SeqLabeling(BaseModel):
        num_classes = args["num_classes"]

        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
        self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim)
        self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
        self.mask = None

    def forward(self, word_seq, word_seq_origin_len):
    def forward(self, word_seq, word_seq_origin_len, truth=None):
        """
        :param word_seq: LongTensor, [batch_size, max_len]
        :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
        :return y: [batch_size, max_len, tag_size]
        :param truth: LongTensor, [batch_size, max_len]
        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                   If truth is not None, return loss, a scalar. Used in training.
        """
        self.mask = self.make_mask(word_seq, word_seq_origin_len)

@ -50,9 +52,16 @@ class SeqLabeling(BaseModel):
        # [batch_size, max_len, hidden_size * direction]
        x = self.Linear(x)
        # [batch_size, max_len, num_classes]
        return x
        if truth is not None:
            return self._internal_loss(x, truth)
        else:
            return self.decode(x)

    def loss(self, x, y):
        """ Since the loss has been computed in forward(), this function simply returns x."""
        return x

    def _internal_loss(self, x, y):
        """
        Negative log likelihood loss.
        :param x: Tensor, [batch_size, max_len, tag_size]
@ -74,12 +83,19 @@ class SeqLabeling(BaseModel):
        mask = mask.to(x)
        return mask

    def prediction(self, x):
    def decode(self, x, pad=True):
        """
        :param x: FloatTensor, [batch_size, max_len, tag_size]
        :param pad: pad the output sequence to equal lengths
        :return prediction: list of [decode path(list)]
        """
        max_len = x.shape[1]
        tag_seq = self.Crf.viterbi_decode(x, self.mask)
        # pad prediction to equal length
        if pad is True:
            for pred in tag_seq:
                if len(pred) < max_len:
                    pred += [0] * (max_len - len(pred))
        return tag_seq

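With this change one forward call covers both regimes: passing truth returns a scalar CRF loss, omitting it returns decoded tag paths. A minimal sketch, with made-up sizes and an args dict assumed to carry the fields read above:

# Sketch only: vocabulary/tag sizes and the args dict below are hypothetical.
import torch
from fastNLP.models.sequence_modeling import SeqLabeling

args = {"vocab_size": 100, "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": 5}
model = SeqLabeling(args)
word_seq = torch.randint(0, 100, (2, 7)).long()   # [batch_size, max_len]
lengths = torch.LongTensor([7, 5])                # original sequence lengths
truth = torch.randint(0, 5, (2, 7)).long()

loss = model(word_seq, lengths, truth)            # scalar loss, consumed by the Trainer
paths = model(word_seq, lengths)                  # list of decoded tag-id lists, used by Tester/Predictor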
@ -97,7 +113,7 @@ class AdvSeqLabel(SeqLabeling):
        num_classes = args["num_classes"]

        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
        self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
        self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
        self.relu = torch.nn.ReLU()
@ -106,11 +122,12 @@ class AdvSeqLabel(SeqLabeling):

        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

    def forward(self, word_seq, word_seq_origin_len):
    def forward(self, word_seq, word_seq_origin_len, truth=None):
        """
        :param word_seq: LongTensor, [batch_size, max_len]
        :param word_seq_origin_len: list of int.
        :return y: [batch_size, max_len, tag_size]
        :param truth: LongTensor, [batch_size, max_len]
        :return y:
        """
        self.mask = self.make_mask(word_seq, word_seq_origin_len)

@ -129,4 +146,7 @@ class AdvSeqLabel(SeqLabeling):
        x = self.Linear2(x)
        x = x.view(batch_size, max_len, -1)
        # [batch_size, max_len, num_classes]
        return x
        if truth is not None:
            return self._internal_loss(x, truth)
        else:
            return self.decode(x)
@ -55,14 +55,13 @@ class SelfAttention(nn.Module):
        input = input.contiguous()
        size = input.size()  # [bsz, len, nhid]

        input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops, bsz, len]
        input_origin = input_origin.transpose(0, 1).contiguous()  # [bsz, hops, len]

        y1 = self.tanh(self.ws1(self.drop(input)))  # [bsz, len, nhid] --> [bsz, len, attention-unit]
        attention = self.ws2(y1).transpose(1, 2).contiguous()  # [bsz, len, attention-unit] --> [bsz, len, hop] --> [bsz, hop, len]
        attention = self.ws2(y1).transpose(1,
                                           2).contiguous()  # [bsz, len, attention-unit] --> [bsz, len, hop] --> [bsz, hop, len]

        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding tokens
        attention = F.softmax(attention, 2)  # [bsz, hop, len]
        return torch.bmm(attention, input), self.penalization(attention)  # output1 --> [bsz, hop, nhid]
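The -999999 term drives the softmax weight of padding positions to effectively zero before the attention is applied. A small numeric sketch of that masking step, with made-up values:

# Numeric illustration of the padding mask used above (values are made up).
import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 0.5]])     # one attention hop over a length-3 sequence
input_origin = torch.tensor([[3, 7, 0]])     # token ids; 0 marks padding
masked = scores + (-999999 * (input_origin == 0).float())
print(F.softmax(masked, dim=1))              # third weight is ~0, the remaining weights renormalize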
@ -1,10 +1,10 @@
from .embedding import Embedding
from .linear import Linear
from .lstm import Lstm
from .conv import Conv
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .linear import Linear
from .lstm import LSTM

__all__ = ["Lstm",
__all__ = ["LSTM",
           "Embedding",
           "Linear",
           "Conv",
@ -1,9 +1,10 @@
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter
class Lstm(nn.Module):
    """
    LSTM module


class LSTM(nn.Module):
    """Long Short Term Memory

    Args:
        input_size : input size
@ -13,13 +14,17 @@ class Lstm(nn.Module):
        bidirectional : If True, becomes a bidirectional RNN. Default: False.
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
        super(Lstm, self).__init__()
    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
                 initial_method=None):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                            dropout=dropout, bidirectional=bidirectional)
        initial_parameter(self, initial_method)

    def forward(self, x):
        x, _ = self.lstm(x)
        return x


if __name__ == "__main__":
    lstm = Lstm(10)
    lstm = LSTM(10)
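The wrapper is batch-first and returns only the output sequence, discarding the hidden and cell states. A minimal shape sketch with illustrative sizes:

# Shape sketch for the LSTM wrapper above (sizes are illustrative).
import torch

lstm = LSTM(input_size=10, hidden_size=100)
x = torch.randn(4, 20, 10)   # [batch_size, seq_len, input_size]
out = lstm(x)                # [4, 20, 100]; the (h, c) states are dropped inside forward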
@ -18,7 +18,7 @@ class ConfigSaver(object):
        :return: The section.
        """
        sect = ConfigSection()
        ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
        ConfigLoader().load_config(self.file_path, {sect_name: sect})
        return sect

    def _read_section(self):
reproduction/Char-aware_NLM/main.py (new file, 25 lines)
@ -0,0 +1,25 @@
from fastNLP.core.loss import Loss
from fastNLP.core.preprocess import Preprocessor
from fastNLP.core.trainer import Trainer
from fastNLP.loader.dataset_loader import LMDataSetLoader
from fastNLP.models.char_language_model import CharLM

PICKLE = "./save/"


def train():
    loader = LMDataSetLoader()
    train_data = loader.load()

    pre = Preprocessor(label_is_seq=True, share_vocab=True)
    train_set = pre.run(train_data, pickle_path=PICKLE)

    model = CharLM(50, 50, pre.vocab_size, pre.char_vocab_size)

    trainer = Trainer(task="language_model", loss=Loss("cross_entropy"))

    trainer.train(model, train_set)


if __name__ == "__main__":
    train()
File diff suppressed because it is too large (3 files)
@ -4,12 +4,12 @@ from fastNLP.core.preprocess import ClassPreprocess as Preprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader
from fastNLP.loader.config_loader import ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP
from fastNLP.modules.encoder.embedding import Embedding as Embedding
from fastNLP.modules.encoder.lstm import Lstm
from fastNLP.modules.encoder.lstm import LSTM

train_data_path = 'small_train_data.txt'
dev_data_path = 'small_dev_data.txt'
@ -43,7 +43,7 @@ class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
    def __init__(self, args=None):
        super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
        self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
        self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
        self.lstm = LSTM(input_size=embeding_size, hidden_size=lstm_hidden_size, bidirectional=True)
        self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
        self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ])
    def forward(self,x):
@ -5,7 +5,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
@ -66,8 +66,8 @@ def train():
    ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args})

    # Data Loader
    loader = TokenizeDatasetLoader(cws_data_path)
    train_data = loader.load_pku()
    loader = TokenizeDataSetLoader()
    train_data = loader.load()

    # Preprocessor
    preprocessor = SeqLabelPreprocess()
@ -66,7 +66,7 @@ def train():
    ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args})

    # Data Loader
    loader = PeopleDailyCorpusLoader(pos_tag_data_path)
    loader = PeopleDailyCorpusLoader()
    train_data, _ = loader.load()

    # Preprocessor
@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase):

        # use batch to iterate dataset
        data_iterator = Batch(data, 2, SeqSampler(), False)
        total_data = 0
        for batch_x, batch_y in data_iterator:
            self.assertEqual(len(batch_x), 2)
            total_data += batch_x["text"].size(0)
            self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts))
            self.assertTrue(isinstance(batch_x, dict))
            self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
            self.assertTrue(isinstance(batch_y, dict))
@ -1,20 +1,42 @@
import sys, os
import os
import sys

sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path

from fastNLP.core import metrics
# from sklearn import metrics as skmetrics
import unittest
import numpy as np
from numpy import random
from fastNLP.core.metrics import SeqLabelEvaluator
import torch


def generate_fake_label(low, high, size):
    return random.randint(low, high, size), random.randint(low, high, size)


class TestEvaluator(unittest.TestCase):
    def test_a(self):
        evaluator = SeqLabelEvaluator()
        pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}]
        ans = evaluator(pred, truth)
        print(ans)

    def test_b(self):
        evaluator = SeqLabelEvaluator()
        pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]]
        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}]
        ans = evaluator(pred, truth)
        print(ans)


class TestMetrics(unittest.TestCase):
    delta = 1e-5
    # test for binary, multiclass, multilabel
    data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
    fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]

    def test_accuracy_score(self):
        for y_true, y_pred in self.fake_data:
            for normalize in [True, False]:
@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase):
            # ans = skmetrics.f1_score(y_true, y_pred)
            # self.assertAlmostEqual(ans, test, delta=self.delta)


if __name__ == '__main__':
    unittest.main()
@ -2,9 +2,12 @@ import os
import unittest

from fastNLP.core.predictor import Predictor
from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
from fastNLP.core.preprocess import save_pickle
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.models.cnn_text_classification import CNNText


class TestPredictor(unittest.TestCase):
@ -31,20 +34,41 @@ class TestPredictor(unittest.TestCase):
        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

        os.system("mkdir save")
        save_pickle(class_vocab, "./save/", "class2id.pkl")
        save_pickle(class_vocab, "./save/", "label2id.pkl")
        save_pickle(vocab, "./save/", "word2id.pkl")

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", task="seq_label")
        model = CNNText(model_args)
        import fastNLP.core.predictor as pre
        predictor = Predictor("./save/", pre.text_classify_post_processor)

        results = predictor.predict(network=model, data=infer_data)
        # Load infer data
        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})

        results = predictor.predict(network=model, data=infer_data_set)

        self.assertTrue(isinstance(results, list))
        self.assertGreater(len(results), 0)
        self.assertEqual(len(results), len(infer_data))
        for res in results:
            self.assertTrue(isinstance(res, str))
            self.assertTrue(res in class_vocab.word2idx)

        del model, predictor, infer_data_set

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", pre.seq_label_post_processor)

        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})

        results = predictor.predict(network=model, data=infer_data_set)
        self.assertTrue(isinstance(results, list))
        self.assertEqual(len(results), len(infer_data))
        for i in range(len(infer_data)):
            res = results[i]
            self.assertTrue(isinstance(res, list))
            self.assertEqual(len(res), 5)
            self.assertTrue(isinstance(res[0], str))
            self.assertEqual(len(res), len(infer_data[i]))

        os.system("rm -rf save")
        print("pickle path deleted")
@ -1,8 +1,9 @@
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
@ -21,7 +22,7 @@ class TestTester(unittest.TestCase):
        }
        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                      "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
                      "use_cuda": False, "print_every_step": 1}
                      "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()}

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
@ -34,16 +35,17 @@ class TestTester(unittest.TestCase):
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        data_set = SeqLabelDataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(model_args)

@ -1,8 +1,9 @@
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling

class TestTrainer(unittest.TestCase):
    def test_case_1(self):
        args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
        args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                "save_best_dev": True, "model_name": "default_model_name.pkl",
                "loss": Loss(None),
                "loss": Loss("cross_entropy"),
                "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                "vocab_size": 10,
                "word_emb_dim": 100,
                "rnn_hidden_units": 100,
                "num_classes": 5
                "num_classes": 5,
                "evaluator": SeqLabelEvaluator()
                }
        trainer = SeqLabelTrainer(**args)

@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase):
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        data_set = SeqLabelDataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

@ -9,10 +9,54 @@ input = [1,2,3]

text = "this is text"

doubles = 0.5
doubles = 0.8

tt = 0.5

test = 105

str = "this is a str"

double = 0.5


[t]
x = "this is an test section"



[test-case-2]
double = 0.5

doubles = 0.8

tt = 0.5

test = 105

str = "this is a str"

[another-test]
doubles = 0.8

tt = 0.5

test = 105

str = "this is a str"

double = 0.5


[one-another-test]
doubles = 0.8

tt = 0.5

test = 105

str = "this is a str"

double = 0.5

@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase):
            return dict

        test_arg = ConfigSection()
        ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
        ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg})

        section = read_section_from_config(os.path.join("./test/loader", "config"), "test")

@ -1,6 +1,7 @@
import os
import unittest

from fastNLP.loader.dataset_loader import POSDatasetLoader, LMDatasetLoader, TokenizeDatasetLoader, \
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
    PeopleDailyCorpusLoader, ConllLoader


@ -8,34 +9,34 @@ class TestDatasetLoader(unittest.TestCase):
    def test_case_1(self):
        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
        lines = data.split("\n")
        answer = POSDatasetLoader.parse(lines)
        answer = POSDataSetLoader.parse(lines)
        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
        self.assertListEqual(answer, truth, "POS Dataset Loader")

    def test_case_TokenizeDatasetLoader(self):
        loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
        data = loader.load_pku(max_seq_len=32)
        print("pass TokenizeDatasetLoader test!")
        loader = TokenizeDataSetLoader()
        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
        print("pass TokenizeDataSetLoader test!")

    def test_case_POSDatasetLoader(self):
        loader = POSDatasetLoader("./test/data_for_tests/people.txt")
        data = loader.load()
        datas = loader.load_lines()
        print("pass POSDatasetLoader test!")
        loader = POSDataSetLoader()
        data = loader.load("./test/data_for_tests/people.txt")
        datas = loader.load_lines("./test/data_for_tests/people.txt")
        print("pass POSDataSetLoader test!")

    def test_case_LMDatasetLoader(self):
        loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
        data = loader.load()
        datas = loader.load_lines()
        print("pass TokenizeDatasetLoader test!")
        loader = LMDataSetLoader()
        data = loader.load("./test/data_for_tests/charlm.txt")
        datas = loader.load_lines("./test/data_for_tests/charlm.txt")
        print("pass TokenizeDataSetLoader test!")

    def test_PeopleDailyCorpusLoader(self):
        loader = PeopleDailyCorpusLoader("./test/data_for_tests/people_daily_raw.txt")
        _, _ = loader.load()
        loader = PeopleDailyCorpusLoader()
        _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt")

    def test_ConllLoader(self):
        loader = ConllLoader("./test/data_for_tests/conll_example.txt")
        _ = loader.load()
        loader = ConllLoader()
        _ = loader.load("./test/data_for_tests/conll_example.txt")


if __name__ == '__main__':
@ -4,14 +4,16 @@ sys.path.append("..")
import argparse
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.loader.dataset_loader import BaseLoader
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.preprocess import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
@ -33,24 +35,27 @@ data_infer_path = args.infer
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": test_args})
    ConfigLoader().load_config(config_dir, {"POS_infer": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "class2id.pkl")
    test_args["num_classes"] = len(index2label)
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    label_vocab = load_pickle(pickle_path, "label2id.pkl")
    test_args["vocab_size"] = len(word_vocab)
    test_args["num_classes"] = len(label_vocab)
    print("vocabularies loaded")

    # Define the same model
    model = SeqLabeling(test_args)
    print("model defined")

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Data Loader
    raw_data_loader = BaseLoader(data_infer_path)
    infer_data = raw_data_loader.load_lines()
    infer_data = SeqLabelDataSet(loader=BaseLoader())
    infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
    print("data set prepared")

    # Inference interface
    infer = SeqLabelInfer(pickle_path)
@ -65,24 +70,18 @@ def train_and_test():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    # Data Loader
    pos_loader = POSDatasetLoader(data_path)
    train_data = pos_loader.load_lines()
    data_set = SeqLabelDataSet()
    data_set.load(data_path)
    train_set, dev_set = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(data_set.word_vocab)
    model_args["num_classes"] = len(data_set.label_vocab)

    # Preprocessor
    p = SeqLabelPreprocess()
    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
    model_args["vocab_size"] = p.vocab_size
    model_args["num_classes"] = p.num_classes
    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

    # Trainer: two definition styles
    # 1
    # trainer = SeqLabelTrainer(trainer_args.data)

    # 2
    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
@ -98,7 +97,7 @@ def train_and_test():
    model = SeqLabeling(model_args)

    # Start training
    trainer.train(model, data_train, data_dev)
    trainer.train(model, train_set, dev_set)
    print("Training finished!")

    # Saver
@ -106,7 +105,9 @@ def train_and_test():
    saver.save_pytorch(model)
    print("Model saved!")

    del model, trainer, pos_loader
    del model, trainer

    change_field_is_target(dev_set, "truth", True)

    # Define the same model
    model = SeqLabeling(model_args)
@ -117,27 +118,21 @@ def train_and_test():

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(save_output=False,
                            save_loss=True,
                            save_best_dev=False,
                            batch_size=4,
    tester = SeqLabelTester(batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            print_every_step=1
                            evaluator=SeqLabelEvaluator()
                            )

    # Start testing with validation data
    tester.test(model, data_dev)

    # print test results
    print(tester.show_metrics())
    tester.test(model, dev_set)
    print("model tested!")


if __name__ == "__main__":
    train_and_test()
    # infer()
    infer()
@ -1,30 +1,32 @@
import os

from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import Preprocessor, load_pickle
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

data_name = "pku_training.utf8"
cws_data_path = "test/data_for_tests/cws_pku_utf_8"
cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
pickle_path = "./save/"
data_infer_path = "test/data_for_tests/people_infer.txt"
config_path = "test/data_for_tests/config"
data_infer_path = "./test/data_for_tests/people_infer.txt"
config_path = "./test/data_for_tests/config"

def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "class2id.pkl")
    index2label = load_pickle(pickle_path, "label2id.pkl")
    test_args["num_classes"] = len(index2label)

    # Define the same model
@ -34,31 +36,29 @@ def infer():
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
    print("model loaded!")

    # Data Loader
    raw_data_loader = BaseLoader(data_infer_path)
    infer_data = raw_data_loader.load_lines()
    # Load infer data
    infer_data = SeqLabelDataSet(loader=BaseLoader())
    infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)

    # Inference interface
    infer = Predictor(pickle_path, "seq_label")
    # inference
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)

    print(results)


def train_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args})
    ConfigLoader().load_config(config_path, {"POS_infer": train_args})

    # Data Loader
    loader = TokenizeDatasetLoader(cws_data_path)
    train_data = loader.load_pku()
    # define dataset
    data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
    data_train.load(cws_data_path)
    train_args["vocab_size"] = len(data_train.word_vocab)
    train_args["num_classes"] = len(data_train.label_vocab)

    # Preprocessor
    p = Preprocessor(label_is_seq=True)
    data_train = p.run(train_data, pickle_path=pickle_path)
    train_args["vocab_size"] = p.vocab_size
    train_args["num_classes"] = p.num_classes
    save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl")

    # Trainer
    trainer = SeqLabelTrainer(**train_args.data)
@ -73,7 +73,7 @@ def train_test():
    saver = ModelSaver("./save/saved_model.pkl")
    saver.save_pytorch(model)

    del model, trainer, loader
    del model, trainer

    # Define the same model
    model = SeqLabeling(train_args)
@ -83,17 +83,16 @@ def train_test():

    # Load test configuration
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})
    test_args["evaluator"] = SeqLabelEvaluator()

    # Tester
    tester = SeqLabelTester(**test_args.data)

    # Start testing
    change_field_is_target(data_train, "truth", True)
    tester.test(model, data_train)

    # print test results
    print(tester.show_metrics())


def test():
    os.makedirs("save", exist_ok=True)
@ -1,11 +1,12 @@
import os

from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver
@ -21,18 +22,17 @@ def test_training():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader("_").load_config(config_dir, {
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    # Data Loader
    pos_loader = POSDatasetLoader(data_path)
    train_data = pos_loader.load_lines()
    data_set = SeqLabelDataSet()
    data_set.load(data_path)
    data_train, data_dev = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(data_set.word_vocab)
    model_args["num_classes"] = len(data_set.label_vocab)

    # Preprocessor
    p = SeqLabelPreprocess()
    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
    model_args["vocab_size"] = p.vocab_size
    model_args["num_classes"] = p.num_classes
    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
@ -55,7 +55,7 @@ def test_training():
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)

    del model, trainer, pos_loader
    del model, trainer

    # Define the same model
    model = SeqLabeling(model_args)
@ -65,21 +65,16 @@ def test_training():

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(save_output=False,
                            save_loss=True,
                            save_best_dev=False,
                            batch_size=4,
    tester = SeqLabelTester(batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            print_every_step=1
                            evaluator=SeqLabelEvaluator()
                            )

    # Start testing with validation data
    change_field_is_target(data_dev, "truth", True)
    tester.test(model, data_dev)

    loss, accuracy = tester.metrics
    assert 0 < accuracy < 1
@ -9,13 +9,14 @@ sys.path.append("..")
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.loss import Loss
from fastNLP.core.dataset import TextClassifyDataSet
from fastNLP.core.preprocess import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
@ -34,21 +35,18 @@ config_dir = args.config
def infer():
    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader(train_data_dir)
    data = ds_loader.load()
    unlabeled_data = [x[0] for x in data]
    word_vocab = load_pickle(save_dir, "word2id.pkl")
    label_vocab = load_pickle(save_dir, "label2id.pkl")
    print("vocabulary size:", len(word_vocab))
    print("number of classes:", len(label_vocab))

    # pre-process data
    pre = ClassPreprocess()
    data = pre.run(data, pickle_path=save_dir)
    print("vocabulary size:", pre.vocab_size)
    print("number of classes:", pre.num_classes)
    infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
    infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})

    model_args = ConfigSection()
    # TODO: load from config file
    model_args["vocab_size"] = pre.vocab_size
    model_args["num_classes"] = pre.num_classes
    # ConfigLoader.load_config(config_dir, {"text_class_model": model_args})
    model_args["vocab_size"] = len(word_vocab)
    model_args["num_classes"] = len(label_vocab)
    ConfigLoader.load_config(config_dir, {"text_class_model": model_args})

    # construct model
    print("Building model...")
@ -59,7 +57,7 @@ def infer():
    print("model loaded!")

    infer = ClassificationInfer(pickle_path=save_dir)
    results = infer.predict(cnn, unlabeled_data)
    results = infer.predict(cnn, infer_data)
    print(results)


@ -69,32 +67,23 @@ def train():

    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader(train_data_dir)
    data = ds_loader.load()
    print(data[0])
    data = TextClassifyDataSet(loader=ClassDataSetLoader())
    data.load(train_data_dir)

    # pre-process data
    pre = ClassPreprocess()
    data_train = pre.run(data, pickle_path=save_dir)
    print("vocabulary size:", pre.vocab_size)
    print("number of classes:", pre.num_classes)
    print("vocabulary size:", len(data.word_vocab))
    print("number of classes:", len(data.label_vocab))
    save_pickle(data.word_vocab, save_dir, "word2id.pkl")
    save_pickle(data.label_vocab, save_dir, "label2id.pkl")

    model_args["num_classes"] = pre.num_classes
    model_args["vocab_size"] = pre.vocab_size
    model_args["num_classes"] = len(data.label_vocab)
    model_args["vocab_size"] = len(data.word_vocab)

    # construct model
    print("Building model...")
    model = CNNText(model_args)

    # ConfigSaver().save_config(config_dir, {"text_class_model": model_args})

    # train
    print("Training...")

    # 1
    # trainer = ClassificationTrainer(train_args)

    # 2
    trainer = ClassificationTrainer(epochs=train_args["epochs"],
                                    batch_size=train_args["batch_size"],
                                    validate=train_args["validate"],
@ -104,7 +93,7 @@ def train():
                                    model_name=model_name,
                                    loss=Loss("cross_entropy"),
                                    optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
    trainer.train(model, data_train)
    trainer.train(model, data)

    print("Training finished!")

@ -115,4 +104,4 @@ def train():

if __name__ == "__main__":
    train()
    # infer()
    infer()
@ -21,7 +21,7 @@ class TestConfigSaver(unittest.TestCase):

        standard_section = ConfigSection()
        t_section = ConfigSection()
        ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section})
        ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section})

        config_saver = ConfigSaver(config_file_path)

@ -48,7 +48,7 @@ class TestConfigSaver(unittest.TestCase):
        one_another_test_section = ConfigSection()
        a_test_case_2_section = ConfigSection()

        ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section,
        ConfigLoader().load_config(config_file_path, {"test": test_section,
                                                      "another-test": another_test_section,
                                                      "t": at_section,
                                                      "one-another-test": one_another_test_section,
@ -54,7 +54,7 @@ def mock_cws():
    class2id = Vocabulary(need_default=False)
    label_list = ['B', 'M', 'E', 'S']
    class2id.update(label_list)
    save_pickle(class2id, "./mock/", "class2id.pkl")
    save_pickle(class2id, "./mock/", "label2id.pkl")

    model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
    config_file = """
@ -115,7 +115,7 @@ def mock_pos_tag():
    idx2label = Vocabulary(need_default=False)
    label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
    idx2label.update(label_list)
    save_pickle(idx2label, "./mock/", "class2id.pkl")
    save_pickle(idx2label, "./mock/", "label2id.pkl")

    model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
    config_file = """
@ -163,7 +163,7 @@ def mock_text_classify():
    idx2label = Vocabulary(need_default=False)
    label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
    idx2label.update(label_list)
    save_pickle(idx2label, "./mock/", "class2id.pkl")
    save_pickle(idx2label, "./mock/", "label2id.pkl")

    model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
    config_file = """