diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index dfd888414..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -545,7 +545,6 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
index 804c92957..9dbf32e48 100644
--- a/egs/tedlium3/ASR/conformer_ctc2/transformer.py
+++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
@@ -612,7 +612,6 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
-        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.

diff --git a/icefall/utils.py b/icefall/utils.py
index a04bedffd..427755090 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1391,20 +1391,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")


-def make_pad_mask(
-    lengths: torch.Tensor,
-    max_len: int = 0,
-    pad_left: bool = False,
-) -> torch.Tensor:
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     """
     Args:
       lengths:
         A 1-D tensor containing sentence lengths.
       max_len:
         The length of masks.
-      pad_left:
-        If ``False`` (default), padding is on the right.
-        If ``True``, padding is on the left.
     Returns:
       Return a 2-D bool tensor, where masked positions are
       filled with `True` and non-masked positions are
@@ -1421,14 +1414,9 @@ def make_pad_mask(
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
     seq_range = torch.arange(0, max_len, device=lengths.device)
-    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

-    if pad_left:
-        mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
-    else:
-        mask = expanded_lengths >= lengths.unsqueeze(-1)
-
-    return mask
+    return expaned_lengths >= lengths.unsqueeze(-1)


 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
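For reference, a minimal self-contained sketch (not part of the patch) of how the simplified make_pad_mask behaves after this change; the function body is condensed from the patched lines above, and the example lengths are hypothetical:

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Mask is True at padded (right-side) positions, False elsewhere."""
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    # Each row holds 0..max_len-1; comparing against that row's length
    # marks every position at or beyond the length as padding.
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    return expaned_lengths >= lengths.unsqueeze(-1)


lengths = torch.tensor([1, 3, 2])  # hypothetical batch of sentence lengths
print(make_pad_mask(lengths))
# tensor([[False,  True,  True],
#         [False, False, False],
#         [False, False,  True]])

Since the pad_left option is removed, a caller that needs a left-padding mask would have to build it itself, e.g. via the removed branch's expression expanded_lengths < (max_len - lengths).unsqueeze(1).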