Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 20:57:37 +08:00)

commit fcf5af93d8 (parent: 8fae3bc2e7)

    Modify batch; add the pipeline and processor interfaces
fastNLP/api/__init__.py  (new, empty file)

fastNLP/api/pipeline.py  (new file, 23 lines)
@@ -0,0 +1,23 @@
+from fastNLP.api.processor import Processor
+
+
+
+class Pipeline:
+    def __init__(self):
+        self.pipeline = []
+
+    def add_processor(self, processor):
+        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
+        processor_name = type(processor)
+        self.pipeline.append((processor_name, processor))  # store (name, processor) pairs; process() unpacks them
+
+    def process(self, dataset):
+        assert len(self.pipeline) != 0, "You need to add some processor first."
+
+        for proc_name, proc in self.pipeline:
+            dataset = proc(dataset)
+
+        return dataset
+
+    def __call__(self, *args, **kwargs):
+        return self.process(*args, **kwargs)
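For orientation, a minimal usage sketch of the new Pipeline interface. The processor subclass and the field names ("words", "word_ids") are illustrative placeholders, not part of this commit; the Processor base class is added in fastNLP/api/processor.py below.

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import Processor


class NoOpProcessor(Processor):
    """Hypothetical subclass used only for this sketch: returns the dataset unchanged."""
    def process(self, dataset):
        return dataset


def run_pipeline(dataset):
    pipe = Pipeline()
    pipe.add_processor(NoOpProcessor("words", None))         # output field name defaults to "words"
    pipe.add_processor(NoOpProcessor("words", "word_ids"))   # a real processor would write to the new field
    return pipe(dataset)   # __call__ delegates to process(dataset); processors run in insertion order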
fastNLP/api/processor.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
+
+
+class Processor:
+    def __init__(self, field_name, new_added_field_name):
+        self.field_name = field_name
+        if new_added_field_name is None:
+            self.new_added_field_name = field_name
+        else:
+            self.new_added_field_name = new_added_field_name
+
+    def process(self):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        return self.process(*args, **kwargs)
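The base class only fixes the convention: field_name names the field a processor reads, and new_added_field_name names the field it writes, falling back to the input field when None. A hedged sketch of a concrete subclass, assuming purely for illustration that a dataset behaves like a list of dict-like instances (the actual DataSet API at this commit may differ):

from fastNLP.api.processor import Processor


class LowerCaseProcessor(Processor):
    """Illustrative subclass: lower-cases a text field, writing to new_added_field_name."""
    def process(self, dataset):
        # Assumes each instance supports dict-style access; for illustration only.
        for instance in dataset:
            instance[self.new_added_field_name] = instance[self.field_name].lower()
        return dataset


proc = LowerCaseProcessor("raw_sentence", None)   # None -> result is written back to "raw_sentence"
data = [{"raw_sentence": "FastNLP IS Fun"}]
data = proc(data)                                 # __call__ forwards to process()
assert data[0]["raw_sentence"] == "fastnlp is fun"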
@@ -51,34 +51,20 @@ class Batch(object):
             raise StopIteration
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_idxes = self.idx_list[self.curidx: endidx]
-            padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
-                              for field_name, field_length in self.lengths.items()}
-            batch_x, batch_y = defaultdict(list), defaultdict(list)
+            batch_x, batch_y = {}, {}

-            # transform index to tensor and do padding for sequences
-            batch = []
-            for idx in batch_idxes:
-                x, y = self.dataset.to_tensor(idx, padding_length)
-                batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))
+            indices = self.idx_list[self.curidx:endidx]

-            if self.sort_in_batch:
-                batch = sorted(batch, key=lambda x: x[0], reverse=True)
-
-            for _, x, y in batch:
-                for name, tensor in x.items():
-                    batch_x[name].append(tensor)
-                for name, tensor in y.items():
-                    batch_y[name].append(tensor)
-
-            # combine instances to form a batch
-            for batch in (batch_x, batch_y):
-                for name, tensor_list in batch.items():
-                    if self.use_cuda:
-                        batch[name] = torch.stack(tensor_list, dim=0).cuda()
-                    else:
-                        batch[name] = torch.stack(tensor_list, dim=0)
+            for field_name, field in self.dataset.get_fields():
+                batch = field.get(indices)
+                if not field.tensorable:  # TODO: revise this
+                    pass
+                elif field.is_target:
+                    batch_y[field_name] = batch
+                else:
+                    batch_x[field_name] = batch

             self.curidx = endidx

             return batch_x, batch_y
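The rewritten Batch.__next__ no longer pads and stacks tensors itself: it asks each field of the dataset for an already-batched value via field.get(indices) and routes the result into batch_x or batch_y using the field's is_target flag, skipping non-tensorable fields for now (per the TODO). A toy stand-in for that field interface, with names and behaviour inferred only from this hunk (the real field/DataSet classes may differ):

import torch


class ToyField:
    """Stand-in for whatever dataset.get_fields() yields; illustrative only."""
    def __init__(self, values, is_target=False, tensorable=True):
        self.values = values
        self.is_target = is_target
        self.tensorable = tensorable

    def get(self, indices):
        # Return the values selected for this batch as one tensor.
        return torch.tensor([self.values[i] for i in indices])


# What the new loop effectively does for one batch of indices:
fields = [("word_ids", ToyField([[1, 2], [3, 4], [5, 6]])),
          ("label", ToyField([0, 1, 0], is_target=True))]
indices = [0, 2]
batch_x, batch_y = {}, {}
for field_name, field in fields:
    batch = field.get(indices)
    if not field.tensorable:
        pass                      # non-tensorable fields are skipped for now
    elif field.is_target:
        batch_y[field_name] = batch
    else:
        batch_x[field_name] = batch
# batch_x == {"word_ids": tensor([[1, 2], [5, 6]])}, batch_y == {"label": tensor([0, 0])}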