Fix mask errors when padding audio (#670)

This commit is contained in:
Yuekai Zhang 2022-11-10 22:28:04 +08:00 committed by GitHub
parent 32de2766d5
commit 2f43e4508b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1017,11 +1017,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
return concat(ragged, eos_id, direction="right")
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
@ -1035,8 +1037,7 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = lengths.max()
max_len = max(max_len, lengths.max())
n = lengths.size(0)
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)