From 88d0da7192101c4fcc025696f604e89eb65926ba Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 3 Oct 2022 17:54:56 +0800
Subject: [PATCH] Simplify the learned scaling factor on the modules

---
 .../pruned_transducer_stateless7/conformer.py | 30 +++++++++----------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 8108ce7f7..30e387ce8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -169,7 +169,6 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
-        self.self_attn_scale = LearnedScale()
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -180,7 +179,6 @@
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.1),
         )
-        self.feed_forward_scale = LearnedScale()
 
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,14 +189,15 @@
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.1),
         )
-        self.feed_forward_macaron_scale = LearnedScale()
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
-        self.conv_scale = LearnedScale()
+
         self.norm_final = BasicNorm(d_model)
-        self.final_scale = LearnedScale()
+
+        # scale_alpha is a learned scale on this layer's contribution that helps cope with layerdrop during training.
+        self.scale_alpha = torch.nn.Parameter(torch.tensor(0.0))
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
@@ -284,8 +283,7 @@ class ConformerEncoderLayer(nn.Module):
         alpha = warmup_scale if self.training else 1.0
 
         # macaron style feed forward module
-        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    layerdrop_indicator)
+        src = src + self.feed_forward_macaron(src)
 
         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -295,23 +293,23 @@
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + self.self_attn_scale(src_att, layerdrop_indicator)
+        src = src + src_att
 
         # convolution module
-        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    layerdrop_indicator)
+        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         # feed forward module
-        src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            layerdrop_indicator)
+        src = src + self.feed_forward(src)
 
-
-        src = self.final_scale(src, layerdrop_indicator)
         src = self.norm_final(self.balancer(src))
 
-        if alpha != 1.0:
-            src = alpha * src + (1 - alpha) * src_orig
+        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
+            # the `or self.training` term is there to ensure we always have a
+            # derivative for self.scale_alpha.
+            src_offset = src - src_orig
+            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
+            src = src_orig + src_offset * scale
 
         return src, attn_scores_out
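
Note (not part of the patch): a minimal sketch of how the new combined scale behaves, assuming that layerdrop_indicator == 1.0 means the layer is kept (so the scale_alpha term vanishes) and that warmup gives alpha in (0, 1] during training; the helper name combined_scale is illustrative only.

import torch

# Mirrors the expression applied to the layer's residual (src - src_orig)
# in the patched forward():
#     scale = alpha * (1 + scale_alpha * (1 - layerdrop_indicator))
def combined_scale(alpha, scale_alpha, layerdrop_indicator):
    return alpha * (1.0 + scale_alpha * (1.0 - layerdrop_indicator))

scale_alpha = torch.nn.Parameter(torch.tensor(0.0))  # initialized to 0.0, as in the patch

print(combined_scale(1.0, scale_alpha, 1.0))  # layer kept, no warmup: scale == alpha == 1.0
print(combined_scale(0.5, scale_alpha, 0.0))  # layer dropped during warmup: 0.5 at init

At initialization this matches the old "if alpha != 1.0" path, which computed src_orig + alpha * (src - src_orig); scale_alpha only changes the result once training moves it away from 0.0.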