From 2f43e4508b9eada64a9b89a4576935fb0b72694c Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 10 Nov 2022 22:28:04 +0800 Subject: [PATCH] fix mask errors when padding audios (#670) --- icefall/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index e83fccdde..c502cb4d8 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1017,11 +1017,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: return concat(ragged, eos_id, direction="right") -def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ Args: lengths: A 1-D tensor containing sentence lengths. + max_len: + The length of masks. Returns: Return a 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are @@ -1035,8 +1037,7 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: [False, False, False, False, False]]) """ assert lengths.ndim == 1, lengths.ndim - - max_len = lengths.max() + max_len = max(max_len, lengths.max()) n = lengths.size(0) expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)