mirror of https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
seq_len_to_mask: use max_len directly instead of comparing it against the longest length in the batch
This commit is contained in:
parent
76e2330a2e
commit
8a766f070b
@@ -659,22 +659,26 @@ def seq_len_to_mask(seq_len, max_len=None):
         >>> mask = seq_len_to_mask(seq_len)
         >>> print(mask.shape)
         (14, 15)
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len, max_len=100)
+        >>> print(mask.size())
+        torch.Size([14, 100])
 
     :param np.ndarray,torch.LongTensor seq_len: shape will be (B,)
-    :param int max_len: pad to this length. Defaults to the longest length in seq_len; but under nn.DataParallel the seq_len seen on
+    :param int max_len: pad to this length. By default (None) the longest length in seq_len is used; but under nn.DataParallel the seq_len seen on
         different cards may differ, so a max_len should be passed in so that every mask is padded to that length.
     :return: np.ndarray or torch.Tensor, shape will be (B, max_len). The element type is bool or torch.uint8.
     """
     if isinstance(seq_len, np.ndarray):
         assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
-        max_len = max(max_len, int(seq_len.max())) if max_len else int(seq_len.max())
+        max_len = int(max_len) if max_len else int(seq_len.max())
         broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
         mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
 
     elif isinstance(seq_len, torch.Tensor):
         assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim()}."
         batch_size = seq_len.size(0)
-        max_len = max(max_len, seq_len.max().long()) if max_len else seq_len.max().long()
+        max_len = int(max_len) if max_len else seq_len.max().long()
         broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
         mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
     else:
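The substance of the change: an explicitly passed max_len is now used as-is instead of being raised to the longest length in the batch, so the returned mask is always exactly (B, max_len) wide. A minimal before/after sketch (the import path from fastNLP.core.utils is an assumption, based on this diff touching that module):

import torch
from fastNLP.core.utils import seq_len_to_mask  # import path is an assumption

seq_len = torch.tensor([2, 5, 8])

# New behavior: max_len is taken at face value, even when it is
# shorter than the longest sequence in the batch.
mask = seq_len_to_mask(seq_len, max_len=4)
print(mask.shape)  # torch.Size([3, 4])
print(mask[2])     # all True: the length-8 row is clipped to 4 positions

# The old behavior computed max(max_len, seq_len.max()), so the same
# call would have returned a (3, 8) mask.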
@@ -237,6 +237,11 @@ class TestSeqLenToMask(unittest.TestCase):
         with self.assertRaises(AssertionError):
             mask = seq_len_to_mask(seq_len)
 
+        # 3. pad to a specified length
+        seq_len = np.random.randint(1, 10, size=(10,))
+        mask = seq_len_to_mask(seq_len, 100)
+        self.assertEqual(100, mask.shape[1])
+
     def test_pytorch_seq_len(self):
         # 1. random test
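The numpy test above exercises the broadcasting comparison at the core of the function. A self-contained sketch of the same idea in plain numpy (no fastNLP dependency), useful for checking the expected mask by hand:

import numpy as np

seq_len = np.array([1, 3, 2])
max_len = 4

# One row of position indices [0, 1, 2, 3] per batch element,
# compared elementwise against that element's length.
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len < seq_len.reshape(-1, 1)

print(mask.shape)  # (3, 4)
print(mask)
# [[ True False False False]
#  [ True  True  True False]
#  [ True  True False False]]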
@@ -250,3 +255,8 @@ class TestSeqLenToMask(unittest.TestCase):
         seq_len = torch.randn(3, 4)
         with self.assertRaises(AssertionError):
             mask = seq_len_to_mask(seq_len)
+
+        # 3. pad to a specified length
+        seq_len = torch.randint(1, 10, size=(10, ))
+        mask = seq_len_to_mask(seq_len, 100)
+        self.assertEqual(100, mask.size(1))
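The docstring's nn.DataParallel note is the motivation for trusting max_len: DataParallel scatters a batch across devices, each replica sees only its slice of seq_len, and masks padded to each slice's local maximum would come back with different widths and fail to gather. A sketch of the failure mode and the fix (import path again an assumption):

import torch
from fastNLP.core.utils import seq_len_to_mask  # import path is an assumption

full_seq_len = torch.tensor([3, 7, 2, 5])
# nn.DataParallel would scatter the batch; emulate two replicas' slices.
slice_a, slice_b = full_seq_len[:2], full_seq_len[2:]

# Without max_len each slice pads to its own local maximum,
# so the per-device masks cannot be concatenated:
print(seq_len_to_mask(slice_a).shape)  # torch.Size([2, 7])
print(seq_len_to_mask(slice_b).shape)  # torch.Size([2, 5])

# Passing a shared max_len makes every replica produce the same width:
print(seq_len_to_mask(slice_a, max_len=10).shape)  # torch.Size([2, 10])
print(seq_len_to_mask(slice_b, max_len=10).shape)  # torch.Size([2, 10])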