mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
[bugfix] auto convert tensor type when batching
This commit is contained in:
parent
65b141c117
commit
2b41e4dd29
@ -69,13 +69,20 @@ class DataSetGetter:
|
||||
|
||||
def may_to_tensor(data):
    """Convert *data* to a ``torch.Tensor`` unless numpy output was requested.

    Looks up the field's element dtype and delegates the actual conversion
    to ``_to_tensor``. When ``self.as_numpy`` is set, *data* is returned
    unchanged.

    :param data: batched field content to (maybe) convert.
    :return: the converted tensor, or *data* untouched.
    :raises TypeError: if the field cannot be converted to a tensor.
    """
    # dim is not needed here; only the element dtype drives the conversion.
    dtype, _ = _get_ele_type_and_dim(data)
    if not self.as_numpy:
        try:
            data, flag = _to_tensor(data, dtype)
        except TypeError:
            # Log which field failed, then re-raise with the original traceback.
            logger.error(f"Field {n} cannot be converted to torch.tensor.")
            raise
    return data
|
||||
def pad(batch_dict):
|
||||
@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype):
|
||||
if field_dtype is not None and isinstance(field_dtype, type)\
|
||||
and issubclass(field_dtype, Number) \
|
||||
and not isinstance(batch, torch.Tensor):
|
||||
if issubclass(field_dtype, np.floating):
|
||||
new_batch = torch.as_tensor(batch).float() # 默认使用float32
|
||||
elif issubclass(field_dtype, np.integer):
|
||||
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制
|
||||
else:
|
||||
new_batch = torch.as_tensor(batch)
|
||||
return new_batch, True
|
||||
flag = True
|
||||
else:
|
||||
return batch, False
|
||||
new_batch = batch
|
||||
flag = False
|
||||
if torch.is_tensor(new_batch):
|
||||
if 'float' in new_batch.dtype.__repr__():
|
||||
new_batch = new_batch.float()
|
||||
elif 'int' in new_batch.dtype.__repr__():
|
||||
new_batch = new_batch.long()
|
||||
return new_batch, flag
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
Loading…
Reference in New Issue
Block a user