更新oneflow.full的参数

This commit is contained in:
x54-729 2022-07-17 00:10:05 +08:00
parent 189050d25d
commit cce133d806

View File

@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder):
else:
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
tensor = oneflow.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
tensor[slices] = field
@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
:return:
"""
shapes = get_shape(batch_field)
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val)
tensor = oneflow.full(shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor