add dataset support for sampler, update batch

This commit is contained in:
yunfan 2018-11-09 20:12:06 +08:00
parent 0cbbfd5221
commit ff6d99bcb2
2 changed files with 9 additions and 2 deletions

View File

@ -56,8 +56,8 @@ class Batch(object):
indices = self.idx_list[self.curidx:endidx]
for field_name, field in self.dataset.get_fields():
batch = field.get(indices)
if not field.tensorable: #TODO 修改
batch = torch.from_numpy(field.get(indices))
if not field.need_tensor: #TODO 修改
pass
elif field.is_target:
batch_y[field_name] = batch

View File

@ -40,6 +40,13 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)
def get_fields(self):
return self.field_arrays
def __len__(self):
field = self.field_arrays.values()[0]
return len(field)
def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.