mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Add beam_search.py
This commit is contained in:
parent
fd6416e6c1
commit
0c58a4b960
1
egs/librispeech/ASR/transducer_stateless2/beam_search.py
Symbolic link
1
egs/librispeech/ASR/transducer_stateless2/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../transducer_stateless/beam_search.py
|
@ -30,6 +30,7 @@ class Joiner(nn.Module):
|
|||||||
self,
|
self,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
decoder_out: torch.Tensor,
|
decoder_out: torch.Tensor,
|
||||||
|
*unused,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -37,6 +38,10 @@ class Joiner(nn.Module):
|
|||||||
Output from the encoder. Its shape is (N, T, self.input_dim).
|
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, self.input_dim).
|
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||||
|
unused:
|
||||||
|
This is a placeholder so that we can reuse
|
||||||
|
transducer_stateless/beam_search.py in this folder as that
|
||||||
|
script assumes the joiner networks accepts 4 inputs.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, self.output_dim).
|
Return a tensor of shape (N, T, U, self.output_dim).
|
||||||
"""
|
"""
|
||||||
@ -53,4 +58,10 @@ class Joiner(nn.Module):
|
|||||||
|
|
||||||
logits = self.output_linear(activations)
|
logits = self.output_linear(activations)
|
||||||
|
|
||||||
|
if not self.training:
|
||||||
|
# We reuse the beam_search.py from transducer_stateless,
|
||||||
|
# which expects that the joiner network outputs
|
||||||
|
# a 2-D tensor.
|
||||||
|
logits = logits.unsqueeze(2).unsqueeze(1)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
Loading…
x
Reference in New Issue
Block a user