Simplify the learned scaling factor on the modules

commit 88d0da7192 (parent b3af9f67ae)
Author: Daniel Povey
Date:   2022-10-03 17:54:56 +08:00


@@ -169,7 +169,6 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
-        self.self_attn_scale = LearnedScale()
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -180,7 +179,6 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
-        self.feed_forward_scale = LearnedScale()
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,14 +189,15 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
-        self.feed_forward_macaron_scale = LearnedScale()
         self.conv_module = ConvolutionModule(d_model,
                                              cnn_module_kernel)
-        self.conv_scale = LearnedScale()
         self.norm_final = BasicNorm(d_model)
-        self.final_scale = LearnedScale()
+        # scale_alpha relates to a scale that can help work around layerdrop during training.
+        self.scale_alpha = torch.nn.Parameter(torch.tensor(0.0))
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
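
Note: the removed LearnedScale class is defined elsewhere in the file and is not shown in this diff. Judging only from the call sites in the hunks below, where it is invoked as scale(x, layerdrop_indicator), a plausible reconstruction of the pattern being deleted is sketched here; the body is a hypothetical stand-in, not the actual implementation.

import torch
import torch.nn as nn

class LearnedScale(nn.Module):
    """Hypothetical stand-in (not from this commit): one learned scalar
    per sub-module, modulated by the layerdrop indicator at the call site."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: torch.Tensor, layerdrop_indicator: float) -> torch.Tensor:
        # Assumed behavior, mirroring the formula this commit moves to the
        # layer level: identity when the layer is kept (indicator == 1.0),
        # and a learned adjustment as the indicator falls toward 0.0.
        return x * (1.0 + self.scale * (1.0 - layerdrop_indicator))

The commit deletes five such per-module objects and keeps a single scale_alpha parameter on the layer instead.
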
@@ -284,8 +283,7 @@ class ConformerEncoderLayer(nn.Module):
         alpha = warmup_scale if self.training else 1.0
         # macaron style feed forward module
-        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    layerdrop_indicator)
+        src = src + self.feed_forward_macaron(src)
         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -295,23 +293,23 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + self.self_attn_scale(src_att, layerdrop_indicator)
+        src = src + src_att
         # convolution module
-        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    layerdrop_indicator)
+        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         # feed forward module
-        src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            layerdrop_indicator)
-        src = self.final_scale(src, layerdrop_indicator)
+        src = src + self.feed_forward(src)
         src = self.norm_final(self.balancer(src))
-        if alpha != 1.0:
-            src = alpha * src + (1 - alpha) * src_orig
+        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
+            # the if(self.training) part is to ensure we have a derivative for
+            # self.scale_alpha.
+            src_offset = src - src_orig
+            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
+            src = src_orig + src_offset * scale
         return src, attn_scores_out
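
The net effect of the change: rather than scaling each sub-module's output through its own LearnedScale, the layer now applies one combined multiplier, alpha * (1.0 + scale_alpha * (1.0 - layerdrop_indicator)), to the whole residual offset src - src_orig. A minimal, self-contained sketch of that final logic follows; ToyLayer and its single linear sub-module are illustrative stand-ins, not code from the commit.

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Minimal sketch of the simplified scaling introduced by this commit.
    def __init__(self, d_model: int):
        super().__init__()
        self.ff = nn.Linear(d_model, d_model)  # stands in for the conformer sub-modules
        # single learned scale replacing the per-module LearnedScale objects
        self.scale_alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, src: torch.Tensor,
                alpha: float = 1.0,
                layerdrop_indicator: float = 1.0) -> torch.Tensor:
        src_orig = src
        src = src + self.ff(src)
        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
            # scale the layer's total residual offset, as in the new code above
            src_offset = src - src_orig
            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
            src = src_orig + src_offset * scale
        return src

# usage: warmup scale alpha=0.5 with a fully dropped layer (indicator=0.0)
layer = ToyLayer(4)
x = torch.randn(2, 4)
y = layer(x, alpha=0.5, layerdrop_indicator=0.0)

With scale_alpha initialized to 0.0 the multiplier reduces to alpha, so training starts from the plain warmup-scaled residual behavior and the layerdrop adjustment is learned from there.
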