add basic Field support

yunfan 2018-09-14 12:38:38 +08:00
parent 82502aa67d
commit bc04b3e7fd
4 changed files with 223 additions and 0 deletions

fastNLP/data/batch.py (new file, 86 lines)

@@ -0,0 +1,86 @@
from collections import defaultdict

import torch


class Batch(object):
    """Iterate over a DataSet in mini-batches, padding each field per batch."""

    def __init__(self, dataset, sampler, batch_size):
        self.dataset = dataset
        self.sampler = sampler
        self.batch_size = batch_size
        self.idx_list = None
        self.curidx = 0

    def __iter__(self):
        # ask the sampler for the order in which instances are visited
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        # per-field lengths of every instance, used to compute padding
        self.lengths = self.dataset.get_length()
        return self

    def __next__(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration
        else:
            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
            # instances that form this batch, in the order given by the sampler
            batch_idxes = self.idx_list[self.curidx: endidx]
            # pad every field to the longest instance within this batch only
            padding_length = {field_name: max(field_length[i] for i in batch_idxes)
                              for field_name, field_length in self.lengths.items()}
            batch_x, batch_y = defaultdict(list), defaultdict(list)
            for idx in batch_idxes:
                x, y = self.dataset.to_tensor(idx, padding_length)
                for name, tensor in x.items():
                    batch_x[name].append(tensor)
                for name, tensor in y.items():
                    batch_y[name].append(tensor)
            # stack the padded per-instance tensors into (batch_size, ...) tensors
            for batch in (batch_x, batch_y):
                for name, tensor_list in batch.items():
                    batch[name] = torch.stack(tensor_list, dim=0)
            self.curidx = endidx
            return batch_x, batch_y


if __name__ == "__main__":
    """simple running example
    """
    from field import TextField, LabelField
    from instance import Instance
    from dataset import DataSet

    texts = ["i am a cat",
             "this is a test of new batch",
             "haha"]
    labels = [0, 1, 0]

    # prepare vocabulary
    vocab = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab)

    # prepare input dataset
    data = DataSet()
    for text, label in zip(texts, labels):
        x = TextField(text.split(), is_target=False)
        y = LabelField(label, is_target=True)
        ins = Instance(text=x, label=y)
        data.append(ins)

    # use the vocabulary to index the text field
    data.index_field("text", vocab)

    # define a naive sequential sampler for the Batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use Batch to iterate over the dataset
    batcher = Batch(data, SeqSampler(), 2)
    for epoch in range(3):
        for batch_x, batch_y in batcher:
            print(batch_x)
            print(batch_y)
            # do stuff
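The padding in __next__ is computed per batch, not over the whole dataset. As a minimal sketch (not part of the commit), here is what that works out to for the first batch of the toy data above, assuming batch_size=2 and the sequential sampler:

# per-field lengths as returned by data.get_length() for the three sentences
lengths = {"text": [4, 7, 1], "label": [1, 1, 1]}
curidx, endidx = 0, 2                      # first batch covers instances 0 and 1
padding_length = {name: max(field_length[curidx:endidx])
                  for name, field_length in lengths.items()}
print(padding_length)                      # {'text': 7, 'label': 1}
# batch_x['text'] therefore stacks two length-7 tensors into shape (2, 7),
# and batch_y['label'] stacks two length-1 tensors into shape (2, 1)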

fastNLP/data/dataset.py (new file, 29 lines)

@@ -0,0 +1,29 @@
from collections import defaultdict


class DataSet(list):
    """A list of Instance objects."""

    def __init__(self, name="", instances=None):
        list.__init__(self)
        self.name = name
        if instances is not None:
            self.extend(instances)

    def index_all(self, vocab):
        # use `vocab` to index every field of every instance
        for ins in self:
            ins.index_all(vocab)

    def index_field(self, field_name, vocab):
        # use `vocab` to index only the named field of every instance
        for ins in self:
            ins.index_field(field_name, vocab)

    def to_tensor(self, idx: int, padding_length: dict):
        # convert the idx-th instance into (input, target) tensor dicts
        ins = self[idx]
        return ins.to_tensor(padding_length)

    def get_length(self):
        # collect {field_name: [length of that field in each instance]}
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
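A quick sketch of what get_length() returns (toy data, not part of the commit; the bare-module imports mirror the demo in batch.py):

from field import TextField, LabelField
from instance import Instance
from dataset import DataSet

data = DataSet(instances=[
    Instance(text=TextField("i am a cat".split(), is_target=False),
             label=LabelField(0, is_target=True)),
    Instance(text=TextField("this is a new test".split(), is_target=False),
             label=LabelField(1, is_target=True)),
])
print(data.get_length())   # {'text': [4, 5], 'label': [1, 1]} (as a defaultdict of lists)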

fastNLP/data/field.py (new file, 70 lines)

@@ -0,0 +1,70 @@
import torch


class Field(object):
    """Base class for a single attribute of an instance (text, label, ...)."""

    def __init__(self, is_target: bool):
        self.is_target = is_target

    def index(self, vocab):
        raise NotImplementedError

    def get_length(self):
        raise NotImplementedError

    def to_tensor(self, padding_length):
        raise NotImplementedError


class TextField(Field):
    def __init__(self, text: list, is_target):
        """
        :param list text: a list of tokens
        :param bool is_target: whether this field is a prediction target
        """
        super(TextField, self).__init__(is_target)
        self.text = text
        self._index = None

    def index(self, vocab):
        # index lazily; repeated calls return the cached result
        if self._index is None:
            self._index = [vocab[c] for c in self.text]
        return self._index

    def get_length(self):
        return len(self.text)

    def to_tensor(self, padding_length: int):
        if self._index is None:
            raise RuntimeError("the text field must be indexed before to_tensor is called")
        pads = []
        if padding_length > self.get_length():
            # pad with 0 up to padding_length
            pads = [0] * (padding_length - self.get_length())
        # 1-D LongTensor of shape (max(len(text), padding_length), )
        return torch.LongTensor(self._index + pads)


class LabelField(Field):
    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None

    def get_length(self):
        # a label always has length 1
        return 1

    def index(self, vocab):
        if self._index is None:
            self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        if self._index is None:
            # assume the raw label is already an integer
            return torch.LongTensor([self.label])
        else:
            return torch.LongTensor([self._index])


if __name__ == "__main__":
    tf = TextField("test the code".split(), is_target=False)
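A small sketch of how the two field types behave, assuming TextField and LabelField are imported from this module (toy vocab, not part of the commit):

vocab = {"test": 1, "the": 2, "code": 3}   # index 0 kept free for padding (a choice made here, not by the commit)
tf = TextField("test the code".split(), is_target=False)
tf.index(vocab)                            # -> [1, 2, 3]
print(tf.to_tensor(5))                     # zero-padded to length 5: tensor([1, 2, 3, 0, 0])

lf = LabelField(1, is_target=True)
print(lf.to_tensor(1))                     # labels ignore padding: tensor([1])

Since to_tensor pads with 0, a real vocabulary would typically reserve index 0 for a dedicated pad token; the demo vocabulary in batch.py does not do this.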

fastNLP/data/instance.py (new file, 38 lines)

@@ -0,0 +1,38 @@
class Instance(object):
    """A single example: a collection of named Field objects."""

    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False
        self.indexes = {}

    def add_field(self, field_name, field):
        self.fields[field_name] = field

    def get_length(self):
        # {field_name: length of that field}
        length = {name: field.get_length() for name, field in self.fields.items()}
        return length

    def index_field(self, field_name, vocab):
        """use `vocab` to index a certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)

    def index_all(self, vocab):
        """use `vocab` to index all fields
        """
        if self.has_index:
            # already indexed; return the cached result
            return self.indexes
        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
        self.indexes = indexes
        self.has_index = True
        return indexes

    def to_tensor(self, padding_length: dict):
        # split the fields into model inputs (tensorX) and targets (tensorY)
        tensorX = {}
        tensorY = {}
        for name, field in self.fields.items():
            if field.is_target:
                tensorY[name] = field.to_tensor(padding_length[name])
            else:
                tensorX[name] = field.to_tensor(padding_length[name])
        return tensorX, tensorY
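Finally, a sketch of how to_tensor splits one instance into input and target dicts according to is_target (toy data, fields imported from field.py; not part of the commit):

ins = Instance(text=TextField("a b".split(), is_target=False),
               label=LabelField(0, is_target=True))
ins.index_field("text", {"a": 1, "b": 2})
x, y = ins.to_tensor({"text": 4, "label": 1})
print(x)   # {'text': tensor([1, 2, 0, 0])} -- input field, padded to length 4
print(y)   # {'label': tensor([0])}         -- target field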