diff --git a/icefall/utils.py b/icefall/utils.py index 427755090..a04bedffd 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1391,13 +1391,20 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: return concat(ragged, eos_id, direction="right") -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: +def make_pad_mask( + lengths: torch.Tensor, + max_len: int = 0, + pad_left: bool = False, +) -> torch.Tensor: """ Args: lengths: A 1-D tensor containing sentence lengths. max_len: The length of masks. + pad_left: + If ``False`` (default), padding is on the right. + If ``True``, padding is on the left. Returns: Return a 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are @@ -1414,9 +1421,14 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: max_len = max(max_len, lengths.max()) n = lengths.size(0) seq_range = torch.arange(0, max_len, device=lengths.device) - expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len) - return expaned_lengths >= lengths.unsqueeze(-1) + if pad_left: + mask = expanded_lengths < (max_len - lengths).unsqueeze(1) + else: + mask = expanded_lengths >= lengths.unsqueeze(-1) + + return mask # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py