diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 389a7cb7f..963cb2cd9 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -177,6 +177,7 @@ class ConformerEncoderLayer(nn.Module):
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
         self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
+        self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0)
         self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
         self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
         self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
@@ -230,14 +231,7 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        # natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it
-        # to 0.2 to 0.6, which is suitable to add to the inputs assuming the output
-        # of the previous convolution layer had a magnitude of around 1.0
-        # (this magnitude of 1.0, or a bit less, like 0.3, is learned but is
-        # dictated by considerations of what is done to the output of the
-        # encoder.
-        post_scale_mha = 0.1
-        src = residual + post_scale_mha * self.dropout(src_att)
+        src = residual + self.post_scale_mha(self.dropout(src_att))
 
         # convolution module
         src = src + self.dropout(self.conv_module(self.scale_conv(src)))
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c72a9dd28..be771b517 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework_2.0",
+        default="transducer_stateless/randcombine1_expscale3_rework_2.0_b",
         help="""The experiment dir.
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved