support left pad for make_pad_mask (#1990)

This commit is contained in:
Yifan Yang 2025-07-16 23:59:04 +08:00 committed by GitHub
parent e22bc78f98
commit 9fd0f2dc1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1391,13 +1391,20 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
return concat(ragged, eos_id, direction="right")
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
def make_pad_mask(
lengths: torch.Tensor,
max_len: int = 0,
pad_left: bool = False,
) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The minimum length of the masks; the actual mask width is ``max(max_len, lengths.max())``.
pad_left:
If ``False`` (default), padding is on the right.
If ``True``, padding is on the left.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
@ -1414,9 +1421,14 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
return expaned_lengths >= lengths.unsqueeze(-1)
if pad_left:
mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
else:
mask = expanded_lengths >= lengths.unsqueeze(-1)
return mask
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py