fix bug in conv_emformer_transducer_stateless2/emformer.py

yaozengwei 2022-06-17 22:52:05 +08:00
parent c27bb1c554
commit 208bbb6325


@@ -542,7 +542,7 @@ class EmformerAttention(nn.Module):
         padding_mask: Optional[torch.Tensor] = None,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
@@ -671,7 +671,7 @@ class EmformerAttention(nn.Module):
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass for inference.

         B: batch size;
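
The two hunks above are annotation-only fixes: both `_forward_impl` and the streaming inference path return three tensors, so the four-element `Tuple[...]` annotation was wrong. Below is a minimal sketch of the corrected shape of the signature with a heavily simplified body; the placeholder logic and variable roles are illustrative assumptions and are not taken from the repository:

from typing import Optional, Tuple

import torch


def _forward_impl_sketch(
    utterance: torch.Tensor,
    right_context: torch.Tensor,
    left_context_key: Optional[torch.Tensor] = None,
    left_context_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Stand-in for EmformerAttention._forward_impl: the real method computes
    # chunk-wise attention; here we only mirror its return arity, which is
    # three tensors, not four.
    output = utterance        # placeholder for the attention output
    key = right_context      # placeholder for the key tensor kept for caching
    val = right_context      # placeholder for the value tensor kept for caching
    return output, key, val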
@@ -1388,7 +1388,7 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
-        M = right_context.size(0) // self.chunk_length - 1
+        M = right_context.size(0) // self.right_context_length - 1
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
         output = utterance
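
The third hunk is the substantive fix. Assuming the hard-copied right context concatenates one block of `right_context_length` frames per chunk, dividing its total length by `right_context_length` (not by `chunk_length`) recovers the number of chunks, and `M` (the number of memory vectors, one per chunk except the last) follows by subtracting one. With the old divisor, `M` was wrong whenever `chunk_length != right_context_length`, which in turn gave `make_pad_mask` the wrong total width. A small self-contained check of that arithmetic; the concrete lengths are made up for illustration and are not the recipe's defaults:

# Illustrative values only; not taken from the recipe's configuration.
chunk_length = 32          # frames per chunk
right_context_length = 8   # hard-copied right-context frames per chunk
num_chunks = 4

# Total length of the hard-copied right context: one block per chunk.
right_context_size = num_chunks * right_context_length  # 32

# Fixed computation: number of memory vectors, one per chunk except the last.
M_fixed = right_context_size // right_context_length - 1  # 4 - 1 = 3

# Old computation divided by chunk_length and miscounted the memory length.
M_buggy = right_context_size // chunk_length - 1          # 1 - 1 = 0

assert M_fixed == num_chunks - 1
assert M_buggy != M_fixed  # with these values the old formula gives 0, not 3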