diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index afcd690e9..4b9c21506 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -16,6 +16,9 @@ import torch import torch.nn as nn +from scaling import ( + ScaledLinear +) class Joiner(nn.Module): @@ -28,8 +31,8 @@ class Joiner(nn.Module): ): super().__init__() - self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) - self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) self.output_linear = nn.Linear(joiner_dim, vocab_size) def forward( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..8f707cf4f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -22,7 +22,11 @@ import random from encoder_interface import EncoderInterface from icefall.utils import add_sos -from scaling import penalize_abs_values_gt +from scaling import ( + penalize_abs_values_gt, + ScaledLinear +) + class Transducer(nn.Module): @@ -64,10 +68,13 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.25, ) - self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.25, + ) + def forward( self,