Add beam_search.py

This commit is contained in:
Fangjun Kuang 2022-04-14 11:49:48 +08:00
parent fd6416e6c1
commit 0c58a4b960
2 changed files with 12 additions and 0 deletions

View File

@ -0,0 +1 @@
../transducer_stateless/beam_search.py

View File

@ -30,6 +30,7 @@ class Joiner(nn.Module):
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
*unused,
) -> torch.Tensor:
"""
Args:
@ -37,6 +38,10 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim).
unused:
This is a placeholder so that we can reuse
transducer_stateless/beam_search.py in this folder as that
script assumes the joiner networks accepts 4 inputs.
Returns:
Return a tensor of shape (N, T, U, self.output_dim).
"""
@ -53,4 +58,10 @@ class Joiner(nn.Module):
logits = self.output_linear(activations)
if not self.training:
# We reuse the beam_search.py from transducer_stateless,
# which expects that the joiner network outputs
# a 2-D tensor.
logits = logits.unsqueeze(2).unsqueeze(1)
return logits