diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index a3e50e385..dfd888414 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index a3e50e385..dfd888414 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 0566cfc81..2d797cc67 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 0566cfc81..2d797cc67 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -549,6 +549,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index d3443dc94..6b62a5993 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -550,6 +550,7 @@ class TransformerDecoderLayer(nn.Module): tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, warmup: float = 1.0, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 2542d9abe..3bc6b88ec 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -537,6 +537,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py index 0c87fdf1b..987a45b1f 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py @@ -567,6 +567,7 @@ class TransformerDecoderLayer(nn.Module): memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py index 9dbf32e48..804c92957 100644 --- a/egs/tedlium3/ASR/conformer_ctc2/transformer.py +++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py @@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module): tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, warmup: float = 1.0, + **kwargs, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. diff --git a/icefall/mmi.py b/icefall/mmi.py index b7777b434..2de8f1aae 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -124,6 +124,7 @@ def _compute_mmi_loss_exact_non_optimized( den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) tot_scores = num_tot_scores - den_scale * den_tot_scores + tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0) loss = -1 * tot_scores.sum() return loss diff --git a/icefall/utils.py b/icefall/utils.py index 427755090..a04bedffd 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1391,13 +1391,20 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: return concat(ragged, eos_id, direction="right") -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: +def make_pad_mask( + lengths: torch.Tensor, + max_len: int = 0, + pad_left: bool = False, +) -> torch.Tensor: """ Args: lengths: A 1-D tensor containing sentence lengths. max_len: The length of masks. + pad_left: + If ``False`` (default), padding is on the right. + If ``True``, padding is on the left. Returns: Return a 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are @@ -1414,9 +1421,14 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: max_len = max(max_len, lengths.max()) n = lengths.size(0) seq_range = torch.arange(0, max_len, device=lengths.device) - expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len) - return expaned_lengths >= lengths.unsqueeze(-1) + if pad_left: + mask = expanded_lengths < (max_len - lengths).unsqueeze(1) + else: + mask = expanded_lengths >= lengths.unsqueeze(-1) + + return mask # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py