mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
更新oneflow.full的参数
This commit is contained in:
parent
189050d25d
commit
cce133d806
@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder):
|
||||
else:
|
||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
|
||||
|
||||
tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
|
||||
tensor = oneflow.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
|
||||
for i, field in enumerate(batch_field):
|
||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
|
||||
tensor[slices] = field
|
||||
@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
|
||||
:return:
|
||||
"""
|
||||
shapes = get_shape(batch_field)
|
||||
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val)
|
||||
tensor = oneflow.full(shapes, dtype=dtype, fill_value=pad_val)
|
||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
|
||||
return tensor
|
||||
|
Loading…
Reference in New Issue
Block a user