Mirror of https://github.com/k2-fsa/icefall.git
Simplify the learned scaling factor on the modules

parent b3af9f67ae
commit 88d0da7192
@@ -169,7 +169,6 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
-        self.self_attn_scale = LearnedScale()

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -180,7 +179,6 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
-        self.feed_forward_scale = LearnedScale()

         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,14 +189,15 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
-        self.feed_forward_macaron_scale = LearnedScale()

         self.conv_module = ConvolutionModule(d_model,
                                              cnn_module_kernel)
-        self.conv_scale = LearnedScale()


         self.norm_final = BasicNorm(d_model)
-        self.final_scale = LearnedScale()

+        # scale_alpha relates to a scale that can help work around layerdrop during training.
+        self.scale_alpha = torch.nn.Parameter(torch.tensor(0.0))

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
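Note (not part of the commit): the per-module LearnedScale() wrappers are removed from __init__ and replaced by a single zero-initialized scalar, scale_alpha, so at initialization the residual scaling reduces to the plain warm-up factor. A minimal, hypothetical sketch of that pattern in plain PyTorch (the demo class and method names are illustrative, not from icefall):

import torch
import torch.nn as nn


class ResidualScaleDemo(nn.Module):
    """Toy stand-in for the layer: one scalar parameter for the whole
    layer instead of one LearnedScale per sub-module."""

    def __init__(self) -> None:
        super().__init__()
        # Zero-initialized, exactly as in the diff above.
        self.scale_alpha = nn.Parameter(torch.tensor(0.0))

    def residual_scale(self, alpha: float, layerdrop_indicator: float) -> torch.Tensor:
        # Same arithmetic as the forward() hunk further down.
        return alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))


demo = ResidualScaleDemo()
# At init scale_alpha == 0, so the residual scale is just the warm-up alpha.
print(demo.residual_scale(alpha=0.5, layerdrop_indicator=0.0))  # tensor(0.5000, grad_fn=...)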
@@ -284,8 +283,7 @@ class ConformerEncoderLayer(nn.Module):
         alpha = warmup_scale if self.training else 1.0

         # macaron style feed forward module
-        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    layerdrop_indicator)
+        src = src + self.feed_forward_macaron(src)

         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -295,23 +293,23 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + self.self_attn_scale(src_att, layerdrop_indicator)
+        src = src + src_att

         # convolution module
-        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    layerdrop_indicator)
+        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)


         # feed forward module
-        src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            layerdrop_indicator)

-        src = self.final_scale(src, layerdrop_indicator)
+        src = src + self.feed_forward(src)

         src = self.norm_final(self.balancer(src))

-        if alpha != 1.0:
-            src = alpha * src + (1 - alpha) * src_orig
+        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
+            # the if(self.training) part is to ensure we have a derivative for
+            # self.scale_alpha.
+            src_offset = src - src_orig
+            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
+            src = src_orig + src_offset * scale

         return src, attn_scores_out
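Worked example (not from the commit): the new blending at the end of forward() reduces to the old warm-up interpolation whenever layerdrop_indicator == 1.0, since the scale_alpha term then vanishes; when the indicator drops below 1.0, the single learned scale_alpha rescales the whole layer's residual instead of each sub-module carrying its own LearnedScale. The helper below just reproduces that arithmetic on made-up tensors; the function name and shapes are illustrative:

import torch


def blend(src_orig: torch.Tensor, src: torch.Tensor, alpha: float,
          layerdrop_indicator: float, scale_alpha: torch.Tensor) -> torch.Tensor:
    """Same arithmetic as the new `if` block in forward()."""
    src_offset = src - src_orig
    scale = alpha * (1.0 + scale_alpha * (1.0 - layerdrop_indicator))
    return src_orig + src_offset * scale


torch.manual_seed(0)
src_orig = torch.randn(10, 4, 256)        # toy (time, batch, d_model) input
src = src_orig + torch.randn(10, 4, 256)  # pretend output after the sub-modules
scale_alpha = torch.tensor(0.25)          # some learned value; starts at 0.0

# layerdrop_indicator == 1.0: identical to the old `alpha * src + (1 - alpha) * src_orig`.
new = blend(src_orig, src, alpha=0.5, layerdrop_indicator=1.0, scale_alpha=scale_alpha)
old = 0.5 * src + (1 - 0.5) * src_orig
print(torch.allclose(new, old))  # True

# layerdrop_indicator == 0.0: scale_alpha boosts (or damps) the layer's residual.
print(blend(src_orig, src, 1.0, 0.0, scale_alpha).sub(src_orig).norm()
      / (src - src_orig).norm())  # ~1.25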