mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
fix mask errors when padding audios (#670)
This commit is contained in:
parent
32de2766d5
commit
2f43e4508b
@ -1017,11 +1017,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
|
||||
return concat(ragged, eos_id, direction="right")
|
||||
|
||||
|
||||
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
lengths:
|
||||
A 1-D tensor containing sentence lengths.
|
||||
max_len:
|
||||
The length of masks.
|
||||
Returns:
|
||||
Return a 2-D bool tensor, where masked positions
|
||||
are filled with `True` and non-masked positions are
|
||||
@ -1035,8 +1037,7 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
[False, False, False, False, False]])
|
||||
"""
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
|
||||
max_len = lengths.max()
|
||||
max_len = max(max_len, lengths.max())
|
||||
n = lengths.size(0)
|
||||
|
||||
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
||||
|
Loading…
x
Reference in New Issue
Block a user