diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
index d76a913a5..b9c465398 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
@@ -23,7 +23,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
 
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim,
+                                          initial_speed=0.5)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index f1a3d4d11..ab729a429 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -61,9 +61,15 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
-        # could perhaps separate this into 2 linear projections, one
-        # for lm and one for am.
-        self.simple_joiner = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size)
+        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size)
+        with torch.no_grad():
+            # Initialize the two projections to be the same; this will be
+            # convenient for the real joiner, which adds the encoder
+            # (acoustic-model/am) and decoder (language-model/lm) embeddings
+            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
+            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
+
 
     def forward(
         self,
@@ -140,8 +146,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens
 
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
+            lm=self.simple_lm_proj(decoder_out),
+            am=self.simple_am_proj(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
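
Two notes on the changes above, with a sketch. In joiner.py, the initial_speed argument to ScaledLinear (defined in this recipe's scaling.py) controls, as I read its docstring, how fast the layer's parameters learn near the start of training; a value below 1.0 such as 0.5 is meant to make the output projection adapt more conservatively early on, when training can be unstable. In model.py, the shared simple_joiner is split into simple_am_proj and simple_lm_proj, initialized identically. Below is a minimal, self-contained sketch of that tied initialization; it substitutes plain nn.Linear for icefall's ScaledLinear and uses made-up dimensions, so it illustrates the pattern rather than the recipe's exact module:

```python
import torch
import torch.nn as nn

# Stand-in dimensions; the real values come from the training config.
embedding_dim, vocab_size = 512, 500

# Two separate projections, as in the patch, so that the acoustic (am)
# and language (lm) branches can diverge during training...
simple_am_proj = nn.Linear(embedding_dim, vocab_size)
simple_lm_proj = nn.Linear(embedding_dim, vocab_size)

with torch.no_grad():
    # ...but start them out identical, so that at initialization the
    # split behaves exactly like the old shared self.simple_joiner.
    simple_lm_proj.weight.copy_(simple_am_proj.weight)
    simple_lm_proj.bias.copy_(simple_am_proj.bias)

# Sanity check: before any optimizer step, the two projections agree
# on any input.
x = torch.randn(4, 100, embedding_dim)  # (N, T, C), arbitrary
assert torch.allclose(simple_am_proj(x), simple_lm_proj(x))
```

The tied start means the refactoring is a no-op at step 0; the gain is that the two projections are free to specialize to the encoder and decoder embeddings as training proceeds.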
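For context on the second model.py hunk, here is a hedged, standalone sketch of calling k2.rnnt_loss_smoothed with the outputs of the renamed projections. The keyword arguments shown in the diff (lm, am, symbols, termination_symbol, lm_only_scale, and the (loss, (px_grad, py_grad)) return) are taken from the patch; the tensor shapes, the boundary[:, 2] assignment, and the am_only_scale and return_grad arguments are my assumptions about the surrounding code, which the diff does not show:

```python
import torch
import k2  # requires a k2 build with the pruned RNN-T loss

B, T, S, C = 2, 50, 10, 500  # batch, frames, symbols, vocab size (made up)
blank_id = 0

# Logits from the two simple projections: lm from the decoder output
# (S + 1 positions, including the leading blank context) and am from
# the encoder output (T frames). Shapes are assumptions.
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))  # padded label sequences

# boundary rows are [begin_symbol, begin_frame, end_symbol, end_frame];
# the diff shows boundary[:, 3] = x_lens, and boundary[:, 2] = y_lens
# is assumed by analogy.
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=0.25,  # placeholder for lm_scale in the recipe
    am_only_scale=0.0,   # assumed; not shown in the diff
    boundary=boundary,
    return_grad=True,    # needed to get (px_grad, py_grad) back
)
```

The px_grad/py_grad outputs are what the recipe later uses to derive the pruning bounds for the full pruned RNN-T loss.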