Merge pull request #97 from fastnlp/dev
update DataSet and DataSetLoader
Commit 6a1d237c64
@@ -6,91 +6,33 @@ from copy import deepcopy

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
    if has_target is True:
        if label_vocab is None:
            raise RuntimeError("Must provide label vocabulary to transform labels.")
        return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
    else:
        return create_unlabeled_dataset_from_lists(str_lists, word_vocab)


def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
    """Create a DataSet instance that contains labels.

    :param str_lists: list of list of strings, [num_examples, 2, *].
        ::

            [
                [[word_11, word_12, ...], [label_11, label_12, ...]],
                ...
            ]

    :param word_vocab: dict of (str: int), which means (word: index).
    :param label_vocab: dict of (str: int), which means (label: index).
    :return data_set: a DataSet instance.

    """
    data_set = DataSet()
    for example in str_lists:
        word_seq, label_seq = example[0], example[1]
        x = TextField(word_seq, is_target=False)
        y = TextField(label_seq, is_target=True)
        data_set.append(Instance(word_seq=x, label_seq=y))
    data_set.index_field("word_seq", word_vocab)
    data_set.index_field("label_seq", label_vocab)
    return data_set


def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
    """Create a DataSet instance that contains no labels.

    :param str_lists: list of list of strings, [num_examples, *].
        ::

            [
                [word_11, word_12, ...],
                ...
            ]

    :param word_vocab: dict of (str: int), which means (word: index).
    :return data_set: a DataSet instance.

    """
    data_set = DataSet()
    for word_seq in str_lists:
        x = TextField(word_seq, is_target=False)
        data_set.append(Instance(word_seq=x))
    data_set.index_field("word_seq", word_vocab)
    return data_set


class DataSet(list):
    """A DataSet object is a list of Instance objects.

    """

    def __init__(self, name="", instances=None, load_func=None):
    def __init__(self, name="", instances=None):
        """
        :param name: str, the name of the dataset. (default: "")
        :param instances: list of Instance objects. (default: None)
        :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
        """
        list.__init__([])
        self.name = name
        self.origin_len = None
        if instances is not None:
            self.extend(instances)
        self.data_set_load_func = load_func

    def index_all(self, vocab):
        for ins in self:
            ins.index_all(vocab)
        return self

    def index_field(self, field_name, vocab):
        for ins in self:
            ins.index_field(field_name, vocab)
        return self

    def to_tensor(self, idx: int, padding_length: dict):
        """Convert an instance in a dataset to tensor.

@@ -102,7 +44,7 @@ class DataSet(list):

        """
        ins = self[idx]
        return ins.to_tensor(padding_length)
        return ins.to_tensor(padding_length, self.origin_len)

    def get_length(self):
        """Fetch lengths of all fields in all instances in a dataset.

@@ -117,42 +59,9 @@ class DataSet(list):
            lengths[field_name].append(field_length)
        return lengths

    def convert(self, data):
        """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
        raise NotImplementedError

    def convert_with_vocabs(self, data, vocabs):
        """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
        raise NotImplementedError

    def convert_for_infer(self, data, vocabs):
        """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""

    def load(self, data_path, vocabs=None, infer=False):
        """Load data from the given files.

        :param data_path: str, the path to the data
        :param infer: bool. If True, there is no label information in the data. Default: False.
        :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.

        """
        raw_data = self.data_set_load_func(data_path)
        if infer is True:
            self.convert_for_infer(raw_data, vocabs)
        else:
            if vocabs is not None:
                self.convert_with_vocabs(raw_data, vocabs)
            else:
                self.convert(raw_data)

    def load_raw(self, raw_data, vocabs):
        """Load raw data without loader. Used in FastNLP class.

        :param raw_data:
        :param vocabs:
        :return:
        """
        self.convert_for_infer(raw_data, vocabs)

    def shuffle(self):
        random.shuffle(self)
        return self

    def split(self, ratio, shuffle=True):
        """Train/dev splitting

@@ -165,7 +74,7 @@ class DataSet(list):
        """
        assert 0 < ratio < 1
        if shuffle:
            random.shuffle(self)
            self.shuffle()
        split_idx = int(len(self) * ratio)
        dev_set = deepcopy(self)
        train_set = deepcopy(self)

@@ -173,134 +82,46 @@ class DataSet(list):
        del dev_set[split_idx:]
        return train_set, dev_set
    def rename_field(self, old_name, new_name):
        """rename a field
        """
        for ins in self:
            ins.rename_field(old_name, new_name)
        return self

    def set_target(self, **fields):
        """Change the flag of `is_target` for all instances. For fields not set here, leave their `is_target` unchanged.

        :param fields: key-value pairs of field name and `is_target` value (True, False or None).
        """
        for ins in self:
            ins.set_target(**fields)
        return self

    def update_vocab(self, **name_vocab):
        """using certain field data to update vocabulary.

        e.g. ::

            # update word vocab and label vocab separately
            dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
        """
        for field_name, vocab in name_vocab.items():
            for ins in self:
                vocab.update(ins[field_name].contents())
        return self

    def set_origin_len(self, origin_field, origin_len_name=None):
        """make dataset tensor output contain origin_len field.

        e.g. ::

            # output "word_seq_origin_len", lengths based on "word_seq" field
            dataset.set_origin_len("word_seq")
        """
        if origin_field is None:
            self.origin_len = None
        else:
            self.origin_len = (origin_field + "_origin_len", origin_field) \
                if origin_len_name is None else (origin_len_name, origin_field)
        return self


class SeqLabelDataSet(DataSet):
    def __init__(self, instances=None, load_func=POSDataSetLoader().load):
        super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary()

    def convert(self, data):
        """Convert lists of strings into Instances with Fields.

        :param data: 3-level lists. Entries are strings.
        """
        bar = ProgressBar(total=len(data))
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label_seq)
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
            bar.move()
        self.index_field("word_seq", self.word_vocab)
        self.index_field("truth", self.label_vocab)
        # no need to index "word_seq_origin_len"

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label_seq = example[0], example[1]
            # list, list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            y = TextField(label_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("truth", y)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("truth", vocabs["label_vocab"])
        # no need to index "word_seq_origin_len"

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            x_len = LabelField(len(word_seq), is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("word_seq_origin_len", x_len)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        # no need to index "word_seq_origin_len"


class TextClassifyDataSet(DataSet):
    def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
        super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
        self.word_vocab = Vocabulary()
        self.label_vocab = Vocabulary(need_default=False)

    def convert(self, data):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            self.word_vocab.update(word_seq)
            self.label_vocab.update(label)
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", self.word_vocab)
        self.index_field("label", self.label_vocab)

    def convert_with_vocabs(self, data, vocabs):
        for example in data:
            word_seq, label = example[0], example[1]
            # list, str
            x = TextField(word_seq, is_target=False)
            y = LabelField(label, is_target=True)
            instance = Instance()
            instance.add_field("word_seq", x)
            instance.add_field("label", y)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])
        self.index_field("label", vocabs["label_vocab"])

    def convert_for_infer(self, data, vocabs):
        for word_seq in data:
            # list
            x = TextField(word_seq, is_target=False)
            instance = Instance()
            instance.add_field("word_seq", x)
            self.append(instance)
        self.index_field("word_seq", vocabs["word_vocab"])


def change_field_is_target(data_set, field_name, new_target):
    """Change the flag of is_target in a field.

    :param data_set: a DataSet object
    :param field_name: str, the name of the field
    :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.

    """
    for inst in data_set:
        inst.fields[field_name].is_target = new_target


class ProgressBar:

    def __init__(self, count=0, total=0, width=100):
        self.count = count
        self.total = total
        self.width = width

    def move(self):
        self.count += 1
        progress = self.width * self.count // self.total
        sys.stdout.write('{0:3}/{1:3}: '.format(self.count, self.total))
        sys.stdout.write('#' * progress + '-' * (self.width - progress) + '\r')
        if progress == self.width:
            sys.stdout.write('\n')
        sys.stdout.flush()
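Taken together, what the removed SeqLabelDataSet.convert did in one monolithic method is now a chain of the new generic DataSet calls. A minimal sketch of the replacement flow, with an illustrative path (it mirrors the updated seq_labeling example further down in this diff) ::

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.loader.dataset_loader import TokenizeDataSetLoader

    data_set = TokenizeDataSetLoader().load("path/to/train.txt")   # loaders now return a DataSet
    word_vocab, label_vocab = Vocabulary(), Vocabulary()
    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)   # build vocabularies from fields
    data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
    data_set.set_origin_len("word_seq")                  # tensor output gains "word_seq_origin_len"
    data_set.rename_field("label_seq", "truth").set_target(truth=False)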
@@ -18,6 +18,8 @@ class Field(object):
    def to_tensor(self, padding_length):
        raise NotImplementedError

    def contents(self):
        raise NotImplementedError


class TextField(Field):
    def __init__(self, text, is_target):

@@ -57,6 +59,8 @@ class TextField(Field):
        pads = [0] * (padding_length - self.get_length())
        return torch.LongTensor(self._index + pads)

    def contents(self):
        return self.text.copy()


class LabelField(Field):
    """The Field representing a single label. Can be a string or integer.

@@ -92,6 +96,8 @@ class LabelField(Field):
        else:
            return torch.LongTensor([self._index])

    def contents(self):
        return [self.label]


class SeqLabelField(Field):
    def __init__(self, label_seq, is_target=True):

@@ -122,6 +128,8 @@ class SeqLabelField(Field):
        else:
            return torch.LongTensor(self._index + pads)

    def contents(self):
        return self.label_seq.copy()


if __name__ == "__main__":
    tf = TextField("test the code".split(), is_target=False)
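Each Field subclass now exposes contents(), which is exactly what DataSet.update_vocab consumes when building vocabularies. A small sketch, assuming the Vocabulary.update seen elsewhere in this diff ::

    from fastNLP.core.vocabulary import Vocabulary

    tf = TextField("test the code".split(), is_target=False)
    vocab = Vocabulary()
    vocab.update(tf.contents())   # contents() hands back a copy of the raw tokens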
@@ -1,3 +1,5 @@
import torch


class Instance(object):
    """An instance which consists of Fields is an example in the DataSet.

@@ -10,6 +12,28 @@ class Instance(object):

    def add_field(self, field_name, field):
        self.fields[field_name] = field
        return self

    def rename_field(self, old_name, new_name):
        if old_name in self.fields:
            self.fields[new_name] = self.fields.pop(old_name)
            if old_name in self.indexes:
                self.indexes[new_name] = self.indexes.pop(old_name)
        else:
            raise KeyError("error, no such field: {}".format(old_name))
        return self

    def set_target(self, **fields):
        for name, val in fields.items():
            if name in self.fields:
                self.fields[name].is_target = val
        return self

    def __getitem__(self, name):
        if name in self.fields:
            return self.fields[name]
        else:
            raise KeyError("{} not found".format(name))

    def get_length(self):
        """Fetch the length of all fields in the instance.

@@ -24,6 +48,7 @@ class Instance(object):
        """use `vocab` to index certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)
        return self

    def index_all(self, vocab):
        """use `vocab` to index all fields

@@ -35,7 +60,7 @@ class Instance(object):
        self.indexes = indexes
        return indexes

    def to_tensor(self, padding_length: dict):
    def to_tensor(self, padding_length: dict, origin_len=None):
        """Convert instance to tensor.

        :param padding_length: dict of (str: int), which means (field name: padding_length of this field)

@@ -53,4 +78,7 @@ class Instance(object):
        else:
            # is_target is None
            continue
        if origin_len is not None:
            name, field_name = origin_len
            tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
        return tensor_x, tensor_y
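A sketch of the new origin_len plumbing end to end, assuming dict-backed indexing as in the tests below (tokens and lengths illustrative) ::

    ins = Instance(word_seq=TextField(["fast", "nlp"], is_target=False))
    ins.index_field("word_seq", {"fast": 0, "nlp": 1})
    tensor_x, tensor_y = ins.to_tensor({"word_seq": 4},
                                       origin_len=("word_seq_origin_len", "word_seq"))
    # tensor_x["word_seq"] is padded to length 4;
    # tensor_x["word_seq_origin_len"] is LongTensor([2]), the unpadded length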
@@ -2,9 +2,9 @@ import numpy as np
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import create_dataset_from_lists
from fastNLP.core.preprocess import load_pickle
from fastNLP.core.sampler import SequentialSampler
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset


class Predictor(object):

@@ -79,7 +79,8 @@ class Predictor(object):
        :return data_set: a DataSet instance.
        """
        assert isinstance(data, list)
        return create_dataset_from_lists(data, self.word_vocab, has_target=False)
        data = convert_seq_dataset(data)
        data.index_field("word_seq", self.word_vocab)


class SeqLabelInfer(Predictor):
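The Predictor thus prepares raw 2-D word lists with the generic converter instead of create_dataset_from_lists. A hedged sketch of the usage this enables, as exercised in the updated tests below (model, vocab and post_processor are stand-in names) ::

    infer_data_set = convert_seq_dataset(infer_data)
    infer_data_set.index_field("word_seq", vocab)
    predictor = Predictor("./save/", post_processor)
    results = predictor.predict(network=model, data=infer_data_set)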
@@ -1,6 +1,7 @@
import os

from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.dataset import DataSet
from fastNLP.loader.dataset_loader import convert_seq_dataset
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

@@ -178,13 +179,11 @@ class FastNLP(object):
        :param infer_input: 2-D lists of strings
        :return data_set: a DataSet object
        """
        if self.infer_type == "seq_label":
            data_set = SeqLabelDataSet()
            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
            return data_set
        elif self.infer_type == "text_class":
            data_set = TextClassifyDataSet()
            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
        if self.infer_type in ["seq_label", "text_class"]:
            data_set = convert_seq_dataset(infer_input)
            data_set.index_field("word_seq", self.word_vocab)
            if self.infer_type == "seq_label":
                data_set.set_origin_len("word_seq")
            return data_set
        else:
            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
@@ -1,6 +1,71 @@
import os

from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.field import *


def convert_seq_dataset(data):
    """Create a DataSet instance that contains no labels.

    :param data: list of list of strings, [num_examples, *].
        ::

            [
                [word_11, word_12, ...],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for word_seq in data:
        x = TextField(word_seq, is_target=False)
        dataset.append(Instance(word_seq=x))
    return dataset


def convert_seq2tag_dataset(data):
    """Convert a list of data into a DataSet.

    :param data: list of list of strings, [num_examples, *].
        ::

            [
                [ [word_11, word_12, ...], label_1 ],
                [ [word_21, word_22, ...], label_2 ],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for sample in data:
        word_seq, label = sample[0], sample[1]
        ins = Instance()
        ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
            .add_field("label", LabelField(label, is_target=True))
        dataset.append(ins)
    return dataset


def convert_seq2seq_dataset(data):
    """Convert a list of data into a DataSet.

    :param data: list of list of strings, [num_examples, *].
        ::

            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]

    :return: a DataSet.
    """
    dataset = DataSet()
    for sample in data:
        word_seq, label_seq = sample[0], sample[1]
        ins = Instance()
        ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
            .add_field("label_seq", TextField(label_seq, is_target=True))
        dataset.append(ins)
    return dataset


class DataSetLoader(BaseLoader):

@@ -10,8 +75,28 @@ class DataSetLoader(BaseLoader):
        super(DataSetLoader, self).__init__()

    def load(self, path):
        """load data in `path` into a dataset
        """
        raise NotImplementedError

    def convert(self, data):
        """convert a list of data into a dataset
        """
        raise NotImplementedError


class RawDataSetLoader(DataSetLoader):
    def __init__(self):
        super(RawDataSetLoader, self).__init__()

    def load(self, data_path, split=None):
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        lines = lines if split is None else [l.split(split) for l in lines]
        lines = list(filter(lambda x: len(x) > 0, lines))
        return self.convert(lines)

    def convert(self, data):
        return convert_seq_dataset(data)


class POSDataSetLoader(DataSetLoader):
    """Dataset Loader for POS Tag datasets.

@@ -48,7 +133,8 @@ class POSDataSetLoader(DataSetLoader):
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return self.parse(lines)
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):

@@ -75,6 +161,10 @@ class POSDataSetLoader(DataSetLoader):
            data.append([words, labels])
        return data

    def convert(self, data):
        """Convert lists of strings into Instances with Fields.
        """
        return convert_seq2seq_dataset(data)


class TokenizeDataSetLoader(DataSetLoader):
    """

@@ -84,8 +174,7 @@ class TokenizeDataSetLoader(DataSetLoader):
    def __init__(self):
        super(TokenizeDataSetLoader, self).__init__()

    @staticmethod
    def load(data_path, max_seq_len=32):
    def load(self, data_path, max_seq_len=32):
        """
        load pku dataset for Chinese word segmentation
        CWS (Chinese Word Segmentation) pku training dataset format:

@@ -130,7 +219,10 @@ class TokenizeDataSetLoader(DataSetLoader):
            seq_words = words[start:end]
            seq_labels = labels[start:end]
            data.append([seq_words, seq_labels])
        return data
        return self.convert(data)

    def convert(self, data):
        return convert_seq2seq_dataset(data)


class ClassDataSetLoader(DataSetLoader):

@@ -143,7 +235,8 @@ class ClassDataSetLoader(DataSetLoader):
        assert os.path.exists(data_path)
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return self.parse(lines)
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):

@@ -166,16 +259,18 @@ class ClassDataSetLoader(DataSetLoader):
            dataset.append(sentence)
        return dataset

    def convert(self, data):
        return convert_seq2tag_dataset(data)


class ConllLoader(DataSetLoader):
    """loader for conll format files"""

    def __int__(self, data_path):
    def __init__(self):
        """
        :param str data_path: the path to the conll data set
        """
        super(ConllLoader, self).__init__()
        self.data_set = self.parse(self.load(data_path))

    def load(self, data_path):
        """

@@ -183,7 +278,8 @@ class ConllLoader(DataSetLoader):
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return lines
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):

@@ -204,6 +300,9 @@ class ConllLoader(DataSetLoader):
            tokens.append(line.split())
        return sentences

    def convert(self, data):
        pass


class LMDataSetLoader(DataSetLoader):
    """Language Model Dataset Loader

@@ -222,7 +321,8 @@ class LMDataSetLoader(DataSetLoader):
        with open(data_path, "r", encoding="utf=8") as f:
            text = " ".join(f.readlines())
        tokens = text.strip().split()
        return self.sentence_cut(tokens)
        data = self.sentence_cut(tokens)
        return self.convert(data)

    def sentence_cut(self, tokens, sentence_length=15):
        start_idx = 0

@@ -236,6 +336,8 @@ class LMDataSetLoader(DataSetLoader):
        data_set.append([x, y])
        return data_set

    def convert(self, data):
        pass


class PeopleDailyCorpusLoader(DataSetLoader):
    """

@@ -286,3 +388,5 @@ class PeopleDailyCorpusLoader(DataSetLoader):
        ner_examples.append([sent_words, sent_ner])
        return pos_tag_examples, ner_examples

    def convert(self, data):
        pass
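Every concrete loader now funnels its parsed lists through convert, so load returns a ready DataSet instead of nested lists. A minimal sketch, with a hypothetical path and a word_vocab assumed to be built separately via update_vocab ::

    loader = POSDataSetLoader()
    data_set = loader.load("path/to/pos_train.txt")    # parse, then convert_seq2seq_dataset
    data_set.index_field("word_seq", word_vocab)       # index afterwards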
@@ -12,7 +12,7 @@ from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.metrics import SeqLabelEvaluator
@@ -3,7 +3,7 @@ import unittest
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

@@ -51,14 +51,3 @@ class TestCase1(unittest.TestCase):
        self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
        self.assertTrue(isinstance(batch_y, dict))
        self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))


class TestCase2(unittest.TestCase):
    def test(self):
        data = DataSet()
        for text in texts:
            x = TextField(text, is_target=False)
            ins = Instance(text=x)
            data.append(ins)
        data_set = create_dataset_from_lists(texts, vocab, has_target=False)
        self.assertTrue(type(data) == type(data_set))
@@ -1,7 +1,6 @@
import unittest

from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.dataset import create_dataset_from_lists
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset


class TestDataSet(unittest.TestCase):

@@ -19,8 +18,9 @@ class TestDataSet(unittest.TestCase):
    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

    def test_case_1(self):
        data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
                                             label_vocab=self.label_vocab)
        data_set = convert_seq2seq_dataset(self.labeled_data_list)
        data_set.index_field("word_seq", self.word_vocab)
        data_set.index_field("label_seq", self.label_vocab)
        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))

@@ -39,7 +39,8 @@ class TestDataSet(unittest.TestCase):
                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])

    def test_case_2(self):
        data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)
        data_set = convert_seq_dataset(self.unlabeled_data_list)
        data_set.index_field("word_seq", self.word_vocab)

        self.assertEqual(len(data_set), len(self.unlabeled_data_list))
        self.assertTrue(len(data_set) > 0)

@@ -51,193 +52,3 @@ class TestDataSet(unittest.TestCase):
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.unlabeled_data_list[0]])


class TestDataSetConvertion(unittest.TestCase):
    labeled_data_list = [
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
    ]
    unlabeled_data_list = [
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"]
    ]
    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

    def test_case_1(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
            ]
            return labeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path")

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])

        self.assertTrue("truth" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)

    def test_case_2(self):
        def loader(path):
            unlabeled_data_list = [
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"]
            ]
            return unlabeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)

    def test_case_3(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
            ]
            return labeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("truth" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
        self.assertEqual(data_set[0].fields["truth"]._index,
                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)


class TestDataSetConvertionHHH(unittest.TestCase):
    labeled_data_list = [
        [["a", "b", "e", "d"], "A"],
        [["a", "b", "e", "d"], "C"],
        [["a", "b", "e", "d"], "B"],
    ]
    unlabeled_data_list = [
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"]
    ]
    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
    label_vocab = {"A": 1, "B": 2, "C": 3}

    def test_case_1(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], "A"],
                [["a", "b", "e", "d"], "C"],
                [["a", "b", "e", "d"], "B"],
            ]
            return labeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx")

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])

        self.assertTrue("label" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])

    def test_case_2(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], "A"],
                [["a", "b", "e", "d"], "C"],
                [["a", "b", "e", "d"], "B"],
            ]
            return labeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("label" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
        self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])

    def test_case_3(self):
        def loader(path):
            unlabeled_data_list = [
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"]
            ]
            return unlabeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
@@ -1,11 +1,12 @@
import os
import unittest

from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
from fastNLP.core.dataset import DataSet
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.loader.dataset_loader import convert_seq_dataset
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import SeqLabeling

@@ -42,8 +43,8 @@ class TestPredictor(unittest.TestCase):
        predictor = Predictor("./save/", pre.text_classify_post_processor)

        # Load infer data
        infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
        infer_data_set = convert_seq_dataset(infer_data)
        infer_data_set.index_field("word_seq", vocab)

        results = predictor.predict(network=model, data=infer_data_set)

@@ -54,14 +55,12 @@ class TestPredictor(unittest.TestCase):
        self.assertTrue(isinstance(res, str))
        self.assertTrue(res in class_vocab.word2idx)

        del model, predictor, infer_data_set
        del model, predictor
        infer_data_set.set_origin_len("word_seq")

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", pre.seq_label_post_processor)

        infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})

        results = predictor.predict(network=model, data=infer_data_set)
        self.assertTrue(isinstance(results, list))
        self.assertEqual(len(results), len(infer_data))
@@ -1,7 +1,7 @@
import os
import unittest

from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

@@ -35,7 +35,7 @@ class TestTester(unittest.TestCase):
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = SeqLabelDataSet()
        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
@@ -1,7 +1,7 @@
import os
import unittest

from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

@@ -36,7 +36,7 @@ class TestTrainer(unittest.TestCase):
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = SeqLabelDataSet()
        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, False)
@@ -1,13 +1,14 @@
import os

from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

@@ -37,9 +38,9 @@ def infer():
    print("model loaded!")

    # Load infer data
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
    infer_data = RawDataSetLoader().load(data_infer_path)
    infer_data.index_field("word_seq", word2index)
    infer_data.set_origin_len("word_seq")
    # inference
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)

@@ -52,13 +53,18 @@ def train_test():
    ConfigLoader().load_config(config_path, {"POS_infer": train_args})

    # define dataset
    data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
    data_train.load(cws_data_path)
    train_args["vocab_size"] = len(data_train.word_vocab)
    train_args["num_classes"] = len(data_train.label_vocab)
    data_train = TokenizeDataSetLoader().load(cws_data_path)
    word_vocab = Vocabulary()
    label_vocab = Vocabulary()
    data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
    data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
    data_train.set_origin_len("word_seq")
    data_train.rename_field("label_seq", "truth").set_target(truth=False)
    train_args["vocab_size"] = len(word_vocab)
    train_args["num_classes"] = len(label_vocab)

    save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl")
    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    # Trainer
    trainer = SeqLabelTrainer(**train_args.data)

@@ -90,7 +96,7 @@ def train_test():
    tester = SeqLabelTester(**test_args.data)

    # Start testing
    change_field_is_target(data_train, "truth", True)
    data_train.set_target(truth=True)
    tester.test(model, data_train)
@@ -1,6 +1,7 @@
import os

from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import save_pickle

@@ -25,14 +26,19 @@ def test_training():
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    data_set = SeqLabelDataSet()
    data_set.load(data_path)
    data_set = TokenizeDataSetLoader().load(data_path)
    word_vocab = Vocabulary()
    label_vocab = Vocabulary()
    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
    data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
    data_set.set_origin_len("word_seq")
    data_set.rename_field("label_seq", "truth").set_target(truth=False)
    data_train, data_dev = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(data_set.word_vocab)
    model_args["num_classes"] = len(data_set.label_vocab)
    model_args["vocab_size"] = len(word_vocab)
    model_args["num_classes"] = len(label_vocab)

    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")
    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],

@@ -76,5 +82,5 @@ def test_training():
    )

    # Start testing with validation data
    change_field_is_target(data_dev, "truth", True)
    data_dev.set_target(truth=True)
    tester.test(model, data_dev)