mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 13:17:51 +08:00
add basic Field support
This commit is contained in:
parent
82502aa67d
commit
bc04b3e7fd
86
fastNLP/data/batch.py
Normal file
86
fastNLP/data/batch.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Batch(object):
|
||||||
|
def __init__(self, dataset, sampler, batch_size):
|
||||||
|
self.dataset = dataset
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
self.idx_list = None
|
||||||
|
self.curidx = 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.idx_list = self.sampler(self.dataset)
|
||||||
|
self.curidx = 0
|
||||||
|
self.lengths = self.dataset.get_length()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.curidx >= len(self.idx_list):
|
||||||
|
raise StopIteration
|
||||||
|
else:
|
||||||
|
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
|
||||||
|
padding_length = {field_name : max(field_length[self.curidx: endidx])
|
||||||
|
for field_name, field_length in self.lengths.items()}
|
||||||
|
|
||||||
|
batch_x, batch_y = defaultdict(list), defaultdict(list)
|
||||||
|
for idx in range(self.curidx, endidx):
|
||||||
|
x, y = self.dataset.to_tensor(idx, padding_length)
|
||||||
|
for name, tensor in x.items():
|
||||||
|
batch_x[name].append(tensor)
|
||||||
|
for name, tensor in y.items():
|
||||||
|
batch_y[name].append(tensor)
|
||||||
|
|
||||||
|
for batch in (batch_x, batch_y):
|
||||||
|
for name, tensor_list in batch.items():
|
||||||
|
print(name, " ", tensor_list)
|
||||||
|
batch[name] = torch.stack(tensor_list, dim=0)
|
||||||
|
self.curidx += endidx
|
||||||
|
return batch_x, batch_y
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""simple running example
|
||||||
|
"""
|
||||||
|
from field import TextField, LabelField
|
||||||
|
from instance import Instance
|
||||||
|
from dataset import DataSet
|
||||||
|
|
||||||
|
texts = ["i am a cat",
|
||||||
|
"this is a test of new batch",
|
||||||
|
"haha"
|
||||||
|
]
|
||||||
|
labels = [0, 1, 0]
|
||||||
|
|
||||||
|
# prepare vocabulary
|
||||||
|
vocab = {}
|
||||||
|
for text in texts:
|
||||||
|
for tokens in text.split():
|
||||||
|
if tokens not in vocab:
|
||||||
|
vocab[tokens] = len(vocab)
|
||||||
|
|
||||||
|
# prepare input dataset
|
||||||
|
data = DataSet()
|
||||||
|
for text, label in zip(texts, labels):
|
||||||
|
x = TextField(text.split(), False)
|
||||||
|
y = LabelField(label, is_target=True)
|
||||||
|
ins = Instance(text=x, label=y)
|
||||||
|
data.append(ins)
|
||||||
|
|
||||||
|
# use vocabulary to index data
|
||||||
|
data.index_field("text", vocab)
|
||||||
|
|
||||||
|
# define naive sampler for batch class
|
||||||
|
class SeqSampler:
|
||||||
|
def __call__(self, dataset):
|
||||||
|
return list(range(len(dataset)))
|
||||||
|
|
||||||
|
# use bacth to iterate dataset
|
||||||
|
batcher = Batch(data, SeqSampler(), 2)
|
||||||
|
for epoch in range(3):
|
||||||
|
for batch_x, batch_y in batcher:
|
||||||
|
print(batch_x)
|
||||||
|
print(batch_y)
|
||||||
|
# do stuff
|
||||||
|
|
29
fastNLP/data/dataset.py
Normal file
29
fastNLP/data/dataset.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
class DataSet(list):
|
||||||
|
def __init__(self, name="", instances=None):
|
||||||
|
list.__init__([])
|
||||||
|
self.name = name
|
||||||
|
if instances is not None:
|
||||||
|
self.extend(instances)
|
||||||
|
|
||||||
|
def index_all(self, vocab):
|
||||||
|
for ins in self:
|
||||||
|
ins.index_all(vocab)
|
||||||
|
|
||||||
|
def index_field(self, field_name, vocab):
|
||||||
|
for ins in self:
|
||||||
|
ins.index_field(field_name, vocab)
|
||||||
|
|
||||||
|
def to_tensor(self, idx: int, padding_length: dict):
|
||||||
|
ins = self[idx]
|
||||||
|
return ins.to_tensor(padding_length)
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
lengths = defaultdict(list)
|
||||||
|
for ins in self:
|
||||||
|
for field_name, field_length in ins.get_length().items():
|
||||||
|
lengths[field_name].append(field_length)
|
||||||
|
return lengths
|
||||||
|
|
70
fastNLP/data/field.py
Normal file
70
fastNLP/data/field.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class Field(object):
|
||||||
|
def __init__(self, is_target: bool):
|
||||||
|
self.is_target = is_target
|
||||||
|
|
||||||
|
def index(self, vocab):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def to_tensor(self, padding_length):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TextField(Field):
|
||||||
|
def __init__(self, text: list, is_target):
|
||||||
|
"""
|
||||||
|
:param list text:
|
||||||
|
"""
|
||||||
|
super(TextField, self).__init__(is_target)
|
||||||
|
self.text = text
|
||||||
|
self._index = None
|
||||||
|
|
||||||
|
def index(self, vocab):
|
||||||
|
if self._index is None:
|
||||||
|
self._index = [vocab[c] for c in self.text]
|
||||||
|
else:
|
||||||
|
print('error')
|
||||||
|
return self._index
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
return len(self.text)
|
||||||
|
|
||||||
|
def to_tensor(self, padding_length: int):
|
||||||
|
pads = []
|
||||||
|
if self._index is None:
|
||||||
|
print('error')
|
||||||
|
if padding_length > self.get_length():
|
||||||
|
pads = [0 for i in range(padding_length - self.get_length())]
|
||||||
|
# (length, )
|
||||||
|
return torch.LongTensor(self._index + pads)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelField(Field):
|
||||||
|
def __init__(self, label, is_target=True):
|
||||||
|
super(LabelField, self).__init__(is_target)
|
||||||
|
self.label = label
|
||||||
|
self._index = None
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def index(self, vocab):
|
||||||
|
if self._index is None:
|
||||||
|
self._index = vocab[self.label]
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return self._index
|
||||||
|
|
||||||
|
def to_tensor(self, padding_length):
|
||||||
|
if self._index is None:
|
||||||
|
return torch.LongTensor([self.label])
|
||||||
|
else:
|
||||||
|
return torch.LongTensor([self._index])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf = TextField("test the code".split())
|
||||||
|
|
38
fastNLP/data/instance.py
Normal file
38
fastNLP/data/instance.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
class Instance(object):
|
||||||
|
def __init__(self, **fields):
|
||||||
|
self.fields = fields
|
||||||
|
self.has_index = False
|
||||||
|
self.indexes = {}
|
||||||
|
|
||||||
|
def add_field(self, field_name, field):
|
||||||
|
self.fields[field_name] = field
|
||||||
|
|
||||||
|
def get_length(self):
|
||||||
|
length = {name : field.get_length() for name, field in self.fields.items()}
|
||||||
|
return length
|
||||||
|
|
||||||
|
def index_field(self, field_name, vocab):
|
||||||
|
"""use `vocab` to index certain field
|
||||||
|
"""
|
||||||
|
self.indexes[field_name] = self.fields[field_name].index(vocab)
|
||||||
|
|
||||||
|
def index_all(self, vocab):
|
||||||
|
"""use `vocab` to index all fields
|
||||||
|
"""
|
||||||
|
if self.has_index:
|
||||||
|
print("error")
|
||||||
|
return self.indexes
|
||||||
|
indexes = {name : field.index(vocab) for name, field in self.fields.items()}
|
||||||
|
self.indexes = indexes
|
||||||
|
return indexes
|
||||||
|
|
||||||
|
def to_tensor(self, padding_length: dict):
|
||||||
|
tensorX = {}
|
||||||
|
tensorY = {}
|
||||||
|
for name, field in self.fields.items():
|
||||||
|
if field.is_target:
|
||||||
|
tensorY[name] = field.to_tensor(padding_length[name])
|
||||||
|
else:
|
||||||
|
tensorX[name] = field.to_tensor(padding_length[name])
|
||||||
|
|
||||||
|
return tensorX, tensorY
|
Loading…
Reference in New Issue
Block a user