Modify batch; add pipeline and processor interfaces

yh 2018-11-09 18:35:18 +08:00
parent 8fae3bc2e7
commit fcf5af93d8
4 changed files with 49 additions and 25 deletions

0  fastNLP/api/__init__.py  Normal file

23  fastNLP/api/pipeline.py  Normal file

@@ -0,0 +1,23 @@
from fastNLP.api.processor import Processor


class Pipeline:
    def __init__(self):
        # ordered list of (name, processor) pairs
        self.pipeline = []

    def add_processor(self, processor):
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
        # store the processor together with its class name so process() can unpack pairs
        processor_name = type(processor).__name__
        self.pipeline.append((processor_name, processor))

    def process(self, dataset):
        assert len(self.pipeline) != 0, "You need to add some processor first."
        # apply the processors to the dataset in the order they were added
        for proc_name, proc in self.pipeline:
            dataset = proc(dataset)
        return dataset

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
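
Taken together, the two new classes form a minimal processing chain: a Pipeline holds an ordered list of Processors and threads a dataset through them. A usage sketch (the LowercaseProcessor subclass and the dict-style instance access are illustrative assumptions, not part of this commit):

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import Processor


class LowercaseProcessor(Processor):
    # hypothetical processor: lower-cases the tokens stored in field_name
    def __init__(self, field_name, new_added_field_name=None):
        super(LowercaseProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        # assumes each instance supports dict-style access to its fields
        for instance in dataset:
            instance[self.new_added_field_name] = [w.lower() for w in instance[self.field_name]]
        return dataset


pipe = Pipeline()
pipe.add_processor(LowercaseProcessor("words"))
# dataset = pipe(dataset)   # __call__ delegates to process()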

15  fastNLP/api/processor.py  Normal file

@@ -0,0 +1,15 @@
class Processor:
    def __init__(self, field_name, new_added_field_name):
        # field to read from
        self.field_name = field_name
        # field to write to; defaults to overwriting the input field in place
        if new_added_field_name is None:
            self.new_added_field_name = field_name
        else:
            self.new_added_field_name = new_added_field_name

    def process(self, *args, **kwargs):
        # to be overridden by subclasses; __call__ forwards its arguments here
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
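
The constructor encodes a small convention worth noting: when new_added_field_name is None the processor overwrites the field it read from, otherwise it writes its output to the separately named field. With the hypothetical subclass sketched above:

p_inplace = LowercaseProcessor("words")              # reads "words", writes back to "words"
p_copy = LowercaseProcessor("words", "words_lower")  # reads "words", writes to "words_lower"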

11  fastNLP/core/batch.py

@@ -51,34 +51,20 @@ class Batch(object):
             raise StopIteration
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_idxes = self.idx_list[self.curidx: endidx]
-            padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
-                              for field_name, field_length in self.lengths.items()}
-            batch_x, batch_y = defaultdict(list), defaultdict(list)
+            batch_x, batch_y = {}, {}
 
-            # transform index to tensor and do padding for sequences
-            batch = []
-            for idx in batch_idxes:
-                x, y = self.dataset.to_tensor(idx, padding_length)
-                batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))
+            indices = self.idx_list[self.curidx:endidx]
 
-            if self.sort_in_batch:
-                batch = sorted(batch, key=lambda x: x[0], reverse=True)
-
-            for _, x, y in batch:
-                for name, tensor in x.items():
-                    batch_x[name].append(tensor)
-                for name, tensor in y.items():
-                    batch_y[name].append(tensor)
-
-            # combine instances to form a batch
-            for batch in (batch_x, batch_y):
-                for name, tensor_list in batch.items():
-                    if self.use_cuda:
-                        batch[name] = torch.stack(tensor_list, dim=0).cuda()
-                    else:
-                        batch[name] = torch.stack(tensor_list, dim=0)
+            for field_name, field in self.dataset.get_fields():
+                batch = field.get(indices)
+                if not field.tensorable:  # TODO: fix this
+                    pass
+                elif field.is_target:
+                    batch_y[field_name] = batch
+                else:
+                    batch_x[field_name] = batch
 
             self.curidx = endidx
             return batch_x, batch_y
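
The rewritten loop no longer pads and stacks tensors itself; it delegates batching to the dataset's fields, each of which is expected to expose tensorable and is_target attributes plus a get(indices) method returning an already-batched value. A sketch of that implicit contract (ToyField and its padding logic are illustrative assumptions, not code from this commit):

import torch


class ToyField:
    # illustrative field satisfying the contract the new Batch loop relies on
    def __init__(self, data, is_target=False, tensorable=True, pad_val=0):
        self.data = data              # one list of token ids per instance
        self.is_target = is_target    # routed to batch_y instead of batch_x
        self.tensorable = tensorable  # fields with False are skipped by Batch
        self.pad_val = pad_val

    def get(self, indices):
        # pad the selected rows to a common length and stack them into one tensor
        rows = [self.data[i] for i in indices]
        max_len = max(len(r) for r in rows)
        padded = [r + [self.pad_val] * (max_len - len(r)) for r in rows]
        return torch.LongTensor(padded)


f = ToyField([[1, 2], [3, 4, 5], [6]])
f.get([0, 2])  # tensor([[1, 2], [6, 0]])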