diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index c5d862ad8..45ca03dd2 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -542,7 +542,7 @@ class EmformerAttention(nn.Module):
         padding_mask: Optional[torch.Tensor] = None,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
@@ -671,7 +671,7 @@ class EmformerAttention(nn.Module):
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass for inference.

         B: batch size;
@@ -1388,7 +1388,7 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)

-        M = right_context.size(0) // self.chunk_length - 1
+        M = right_context.size(0) // self.right_context_length - 1
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
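
For context on the last hunk, here is a minimal standalone sketch of why the divisor should be `right_context_length` rather than `chunk_length`: the hard-copied right context stacks `right_context_length` frames per chunk, so dividing its length by `right_context_length` recovers the number of chunks, and the `- 1` matches keeping one memory vector per chunk except the last. All concrete values below (chunk length, right-context length, batch size, model dim) are illustrative assumptions, not taken from the recipe's configuration.

```python
import torch

# Illustrative values only (assumptions, not from the patch or the recipe).
chunk_length = 32          # frames per chunk
right_context_length = 8   # hard-copied right-context frames per chunk
num_chunks = 5
B, D = 2, 512              # batch size, model dimension

# right_context stacks right_context_length frames for every chunk:
# shape (num_chunks * right_context_length, B, D).
right_context = torch.zeros(num_chunks * right_context_length, B, D)

# Corrected formula: number of memory vectors, one per chunk except the last.
M = right_context.size(0) // right_context_length - 1
assert M == num_chunks - 1           # 4

# The old expression divided by chunk_length, which undercounts the memory
# bank whenever chunk_length != right_context_length.
M_buggy = right_context.size(0) // chunk_length - 1
print(M, M_buggy)                    # 4 0
```

With the old divisor, `make_pad_mask(M + right_context.size(0) + output_lengths)` would be built for the wrong total length whenever `chunk_length` differs from `right_context_length`, which is the case this fix addresses.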