Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)
fix bug in conv_emformer_transducer_stateless2/emformer.py
commit 208bbb6325
parent c27bb1c554
@@ -542,7 +542,7 @@ class EmformerAttention(nn.Module):
         padding_mask: Optional[torch.Tensor] = None,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Underlying chunk-wise attention implementation."""
         U, B, _ = utterance.size()
         R = right_context.size(0)
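Review note (an addition, not part of the commit): this hunk narrows the return annotation of the chunk-wise attention implementation from four tensors to three, presumably to match what the method actually returns. A minimal sketch, with hypothetical names and shapes, of why the annotation arity matters to callers that unpack the result:

    from typing import Tuple

    import torch

    def attention_stub() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Stand-in for the chunk-wise attention implementation: it returns
        # the attention output plus two cache tensors, i.e. three values,
        # so the three-element Tuple annotation is the one that matches.
        out = torch.zeros(4, 2, 8)   # hypothetical (U, B, D) output
        key = torch.zeros(3, 2, 8)   # hypothetical left-context key cache
        val = torch.zeros(3, 2, 8)   # hypothetical left-context value cache
        return out, key, val

    out, key, val = attention_stub()  # unpacking matches the corrected hint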
@@ -671,7 +671,7 @@ class EmformerAttention(nn.Module):
         left_context_key: torch.Tensor,
         left_context_val: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass for inference.

         B: batch size;
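The same annotation fix, applied to the inference path. A hedged sketch (all names and sizes hypothetical) of how a streaming caller would carry the three returned tensors across chunks, trimming the key/value caches to a fixed left-context length:

    import torch

    LEFT_CONTEXT = 32  # assumed left-context cache length, in frames

    def infer_stub(chunk, left_key, left_val):
        # Stand-in for EmformerAttention.infer: returns the attention output
        # plus updated left-context key/value caches (three tensors).
        out = chunk
        new_key = torch.cat([left_key, chunk])[-LEFT_CONTEXT:]
        new_val = torch.cat([left_val, chunk])[-LEFT_CONTEXT:]
        return out, new_key, new_val

    left_key = torch.zeros(0, 2, 8)   # empty caches before the first chunk
    left_val = torch.zeros(0, 2, 8)
    for chunk in torch.zeros(5, 4, 2, 8):  # five chunks of four frames each
        out, left_key, left_val = infer_stub(chunk, left_key, left_val)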
@@ -1388,7 +1388,7 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)

-        M = right_context.size(0) // self.chunk_length - 1
+        M = right_context.size(0) // self.right_context_length - 1
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
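Review note on the bug itself (an addition, not part of the commit): the fixed line implies that right_context stacks right_context_length frames per chunk, so right_context.size(0) equals num_chunks * right_context_length (this reading is an assumption). Dividing by chunk_length therefore yields the wrong count whenever the two lengths differ, which in turn corrupts the padding_mask computed on the next line. A quick arithmetic check with hypothetical sizes:

    chunk_length = 32          # hypothetical sizes, chosen so the two differ
    right_context_length = 8
    num_chunks = 6
    R = num_chunks * right_context_length  # right_context.size(0) == 48

    M_buggy = R // chunk_length - 1           # 48 // 32 - 1 == 0
    M_fixed = R // right_context_length - 1   # 48 // 8 - 1 == 5
    assert M_fixed == num_chunks - 1          # one entry per chunk, minus one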