diff --git a/egs/librispeech/ASR/transducer_stateless2/beam_search.py b/egs/librispeech/ASR/transducer_stateless2/beam_search.py
new file mode 120000
index 000000000..08cb32ef7
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless2/beam_search.py
@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py
index e56ba859d..b30b6895c 100644
--- a/egs/librispeech/ASR/transducer_stateless2/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py
@@ -30,6 +30,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        *unused,
     ) -> torch.Tensor:
         """
         Args:
@@ -37,6 +38,10 @@ class Joiner(nn.Module):
             Output from the encoder. Its shape is (N, T, self.input_dim).
           decoder_out:
             Output from the decoder. Its shape is (N, U, self.input_dim).
+          unused:
+            This is a placeholder so that we can reuse
+            transducer_stateless/beam_search.py in this folder as that
+            script assumes the joiner network accepts 4 inputs.
         Returns:
           Return a tensor of shape (N, T, U, self.output_dim).
         """
@@ -53,4 +58,10 @@ class Joiner(nn.Module):
 
         logits = self.output_linear(activations)
 
+        if not self.training:
+            # We reuse the beam_search.py from transducer_stateless,
+            # which expects that the joiner network outputs
+            # a 2-D tensor.
+            logits = logits.squeeze(2).squeeze(1)
+
         return logits