mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
add dataset support for sampler, update batch
This commit is contained in:
parent
0cbbfd5221
commit
ff6d99bcb2
@ -56,8 +56,8 @@ class Batch(object):
|
||||
indices = self.idx_list[self.curidx:endidx]
|
||||
|
||||
for field_name, field in self.dataset.get_fields():
|
||||
batch = field.get(indices)
|
||||
if not field.tensorable: #TODO 修改
|
||||
batch = torch.from_numpy(field.get(indices))
|
||||
if not field.need_tensor: #TODO 修改
|
||||
pass
|
||||
elif field.is_target:
|
||||
batch_y[field_name] = batch
|
||||
|
@ -40,6 +40,13 @@ class DataSet(object):
|
||||
assert name in self.field_arrays
|
||||
self.field_arrays[name].append(field)
|
||||
|
||||
def get_fields(self):
|
||||
return self.field_arrays
|
||||
|
||||
def __len__(self):
|
||||
field = self.field_arrays.values()[0]
|
||||
return len(field)
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch lengths of all fields in all instances in a dataset.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user