[bugfix] auto convert tensor type when batching

This commit is contained in:
yunfan 2020-03-15 15:19:42 +08:00
parent 65b141c117
commit 2b41e4dd29

View File

@ -69,13 +69,20 @@ class DataSetGetter:
def may_to_tensor(data):
dtype, dim = _get_ele_type_and_dim(data)
print(dtype, type(dtype))
# print(dtype, type(dtype), str(dtype))
if not self.as_numpy:
try:
data, flag = _to_tensor(data, dtype)
except TypeError as e:
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
# if torch.is_tensor(data):
# str_dtype = str(dtype)
# if 'float' in str_dtype:
# data = data.float()
# elif 'int' in str_dtype:
# data = data.long()
# print(data.dtype)
return data
def pad(batch_dict):
@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype):
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor):
if issubclass(field_dtype, np.floating):
new_batch = torch.as_tensor(batch).float() # 默认使用float32
elif issubclass(field_dtype, np.integer):
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制
else:
new_batch = torch.as_tensor(batch)
return new_batch, True
new_batch = torch.as_tensor(batch)
flag = True
else:
return batch, False
new_batch = batch
flag = False
if torch.is_tensor(new_batch):
if 'float' in new_batch.dtype.__repr__():
new_batch = new_batch.float()
elif 'int' in new_batch.dtype.__repr__():
new_batch = new_batch.long()
return new_batch, flag
except Exception as e:
raise e