mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fix bug in conv_emformer_transducer_stateless2/emformer.py
This commit is contained in:
parent
c27bb1c554
commit
208bbb6325
@ -542,7 +542,7 @@ class EmformerAttention(nn.Module):
|
|||||||
padding_mask: Optional[torch.Tensor] = None,
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
left_context_key: Optional[torch.Tensor] = None,
|
left_context_key: Optional[torch.Tensor] = None,
|
||||||
left_context_val: Optional[torch.Tensor] = None,
|
left_context_val: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Underlying chunk-wise attention implementation."""
|
"""Underlying chunk-wise attention implementation."""
|
||||||
U, B, _ = utterance.size()
|
U, B, _ = utterance.size()
|
||||||
R = right_context.size(0)
|
R = right_context.size(0)
|
||||||
@ -671,7 +671,7 @@ class EmformerAttention(nn.Module):
|
|||||||
left_context_key: torch.Tensor,
|
left_context_key: torch.Tensor,
|
||||||
left_context_val: torch.Tensor,
|
left_context_val: torch.Tensor,
|
||||||
padding_mask: Optional[torch.Tensor] = None,
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Forward pass for inference.
|
"""Forward pass for inference.
|
||||||
|
|
||||||
B: batch size;
|
B: batch size;
|
||||||
@ -1388,7 +1388,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||||
attention_mask = self._gen_attention_mask(utterance)
|
attention_mask = self._gen_attention_mask(utterance)
|
||||||
|
|
||||||
M = right_context.size(0) // self.chunk_length - 1
|
M = right_context.size(0) // self.right_context_length - 1
|
||||||
padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
|
padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
|
||||||
|
|
||||||
output = utterance
|
output = utterance
|
||||||
|
Loading…
x
Reference in New Issue
Block a user