mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
add basic Field support
This commit is contained in:
parent
82502aa67d
commit
bc04b3e7fd
86
fastNLP/data/batch.py
Normal file
86
fastNLP/data/batch.py
Normal file
@ -0,0 +1,86 @@
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
class Batch(object):
    """Iterate over a DataSet in mini-batches of stacked tensors.

    :param dataset: a DataSet-like object providing ``get_length()`` and
        ``to_tensor(idx, padding_length)``
    :param sampler: callable taking the dataset and returning the list of
        sample indices to visit, in order
    :param batch_size: int, maximum number of samples per batch

    Each ``__next__`` yields ``(batch_x, batch_y)``: dicts mapping field
    name to a tensor of shape ``(num_samples_in_batch, padded_length)``.
    """

    def __init__(self, dataset, sampler, batch_size):
        self.dataset = dataset
        self.sampler = sampler
        self.batch_size = batch_size

        # iteration state, (re)initialized by __iter__
        self.idx_list = None
        self.curidx = 0
        self.lengths = None

    def __iter__(self):
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        self.lengths = self.dataset.get_length()
        return self

    def __next__(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration

        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        # BUG FIX: honor the sampler's ordering — resolve positions through
        # idx_list instead of using raw dataset positions, so non-sequential
        # samplers (e.g. shuffling) are handled correctly. Identical to the
        # old behavior for a sequential sampler.
        batch_indices = self.idx_list[self.curidx:endidx]
        # pad every field to the longest sequence inside this batch
        padding_length = {field_name: max(field_length[i] for i in batch_indices)
                          for field_name, field_length in self.lengths.items()}

        batch_x, batch_y = defaultdict(list), defaultdict(list)
        for idx in batch_indices:
            x, y = self.dataset.to_tensor(idx, padding_length)
            for name, tensor in x.items():
                batch_x[name].append(tensor)
            for name, tensor in y.items():
                batch_y[name].append(tensor)

        # stack each field's per-sample tensors into one batch tensor
        for batch in (batch_x, batch_y):
            for name, tensor_list in batch.items():
                batch[name] = torch.stack(tensor_list, dim=0)

        # BUG FIX: was ``self.curidx += endidx``, which over-advanced the
        # cursor and silently skipped samples; also removed a debug print.
        self.curidx = endidx
        return batch_x, batch_y
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """simple running example
    """
    from field import TextField, LabelField
    from instance import Instance
    from dataset import DataSet

    texts = ["i am a cat",
             "this is a test of new batch",
             "haha"
             ]
    labels = [0, 1, 0]

    # build a word -> id vocabulary, ids assigned in order of first appearance
    vocab = {}
    for sentence in texts:
        for token in sentence.split():
            vocab.setdefault(token, len(vocab))

    # wrap each (text, label) pair in an Instance and collect them in a DataSet
    data = DataSet()
    for sentence, target in zip(texts, labels):
        ins = Instance(text=TextField(sentence.split(), False),
                       label=LabelField(target, is_target=True))
        data.append(ins)

    # replace the tokens of the "text" field by their vocabulary ids
    data.index_field("text", vocab)

    # minimal sampler: visit the dataset in sequential order
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use Batch to iterate the dataset two samples at a time
    batcher = Batch(data, SeqSampler(), 2)
    for epoch in range(3):
        for batch_x, batch_y in batcher:
            print(batch_x)
            print(batch_y)
            # do stuff
|
29
fastNLP/data/dataset.py
Normal file
29
fastNLP/data/dataset.py
Normal file
@ -0,0 +1,29 @@
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class DataSet(list):
    """A list of Instance objects with batch-oriented helpers.

    :param str name: optional human-readable name of the dataset
    :param instances: optional iterable of Instance objects to pre-populate
    """

    def __init__(self, name="", instances=None):
        # BUG FIX: was ``list.__init__([])``, which initialized a throwaway
        # temporary list instead of ``self``.
        super().__init__()
        self.name = name
        if instances is not None:
            self.extend(instances)

    def index_all(self, vocab):
        """Index every field of every instance with ``vocab``."""
        for ins in self:
            ins.index_all(vocab)

    def index_field(self, field_name, vocab):
        """Index one named field of every instance with ``vocab``."""
        for ins in self:
            ins.index_field(field_name, vocab)

    def to_tensor(self, idx: int, padding_length: dict):
        """Convert the instance at ``idx`` to an (x_tensors, y_tensors) pair.

        :param int idx: position of the instance in this dataset
        :param dict padding_length: field name -> padding length
        """
        ins = self[idx]
        return ins.to_tensor(padding_length)

    def get_length(self):
        """Return {field name: [length of that field in each instance]}."""
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
|
||||
|
70
fastNLP/data/field.py
Normal file
70
fastNLP/data/field.py
Normal file
@ -0,0 +1,70 @@
|
||||
import torch
|
||||
|
||||
class Field(object):
    """Abstract base class for a single field (column) of an Instance.

    :param bool is_target: True if this field is a training target (y),
        False if it is an input (x).
    """

    def __init__(self, is_target: bool):
        self.is_target = is_target

    def index(self, vocab):
        """Convert raw content to indices using ``vocab``. Override me."""
        pass

    def get_length(self):
        """Return the length of this field's content. Override me."""
        pass

    def to_tensor(self, padding_length):
        """Return a tensor padded to ``padding_length``. Override me."""
        pass


class TextField(Field):
    def __init__(self, text: list, is_target):
        """
        :param list text: list of string tokens
        :param bool is_target: whether this field is a training target
        """
        super(TextField, self).__init__(is_target)
        self.text = text
        self._index = None  # token ids; filled in by index()

    def index(self, vocab):
        """Map every token to its id in ``vocab``.

        Idempotent: the first call caches the id list; later calls return
        the cached list unchanged (``vocab`` is then ignored).
        """
        if self._index is None:
            self._index = [vocab[c] for c in self.text]
        return self._index

    def get_length(self):
        """Number of tokens in this field."""
        return len(self.text)

    def to_tensor(self, padding_length: int):
        """Return the token ids as a LongTensor, right-padded with 0.

        :param int padding_length: target length; if it is smaller than the
            text length, no truncation happens and the raw ids are returned.
        :raises RuntimeError: if index() was never called.
        """
        if self._index is None:
            # BUG FIX: used to print('error') and then crash with an opaque
            # TypeError on ``None + pads``; fail loudly and clearly instead.
            raise RuntimeError("TextField.to_tensor called before index()")
        pads = []
        if padding_length > self.get_length():
            pads = [0] * (padding_length - self.get_length())
        # shape: (length, )
        return torch.LongTensor(self._index + pads)


class LabelField(Field):
    """A single (usually target) label.

    :param label: raw label value; an int is usable directly, anything else
        should be mapped through index() first.
    :param bool is_target: defaults to True, since labels are usually targets.
    """

    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None

    def get_length(self):
        """A label always has length 1."""
        return 1

    def index(self, vocab):
        """Map the raw label through ``vocab``; cached after the first call."""
        if self._index is None:
            self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        """Return a 1-element LongTensor, preferring the indexed value.

        ``padding_length`` is ignored — labels are never padded.
        """
        if self._index is None:
            # not indexed: assume the raw label is already an int id
            return torch.LongTensor([self.label])
        return torch.LongTensor([self._index])
|
||||
|
||||
if __name__ == "__main__":
    # smoke test
    # BUG FIX: TextField.__init__ requires ``is_target``; the original call
    # omitted it and raised a TypeError when the module was run directly.
    tf = TextField("test the code".split(), is_target=False)
|
||||
|
38
fastNLP/data/instance.py
Normal file
38
fastNLP/data/instance.py
Normal file
@ -0,0 +1,38 @@
|
||||
class Instance(object):
    """One example: a named collection of Field objects.

    :param fields: keyword arguments mapping field name -> Field
    """

    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False  # True once index_all has run
        self.indexes = {}       # field name -> indexed representation

    def add_field(self, field_name, field):
        """Attach (or replace) a named field."""
        self.fields[field_name] = field

    def get_length(self):
        """Return {field name: length of that field}."""
        length = {name: field.get_length() for name, field in self.fields.items()}
        return length

    def index_field(self, field_name, vocab):
        """use `vocab` to index certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)

    def index_all(self, vocab):
        """use `vocab` to index all fields
        """
        if self.has_index:
            # already indexed: return the cached result instead of re-indexing
            return self.indexes
        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
        self.indexes = indexes
        # BUG FIX: the flag was never set to True, so the cache branch above
        # was dead code and every call re-indexed all fields.
        self.has_index = True
        return indexes

    def to_tensor(self, padding_length: dict):
        """Split the fields into input (x) and target (y) tensor dicts.

        :param dict padding_length: field name -> padding length for that field
        :return: (tensorX, tensorY) — target fields go to tensorY, the rest to tensorX
        """
        tensorX = {}
        tensorY = {}
        for name, field in self.fields.items():
            if field.is_target:
                tensorY[name] = field.to_tensor(padding_length[name])
            else:
                tensorX[name] = field.to_tensor(padding_length[name])

        return tensorX, tensorY
|
Loading…
Reference in New Issue
Block a user