diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index a1226f712..752a5f774 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -35,7 +35,8 @@ class Joiner(nn.Module): self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, + project_input: bool = True ) -> torch.Tensor: """ Args: @@ -43,13 +44,20 @@ class Joiner(nn.Module): Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. Returns: Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + logit = encoder_out + decoder_out logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 1dd20c546..a9178c8b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -164,11 +164,17 @@ class Transducer(nn.Module): # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=encoder_out, lm=decoder_out, ranges=ranges + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges ) # logits : [B, T, 
prune_range, vocab_size] - logits = self.joiner(am_pruned, lm_pruned) + + # project_input=False since we already applied the encoder and decoder + # input projections (encoder_proj / decoder_proj) prior to + # do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, + project_input=False) pruned_loss = k2.rnnt_loss_pruned( logits=logits,