Make warmup work by scaling layer contributions; leave residual layer-drop

This commit is contained in:
Daniel Povey 2022-03-25 20:34:33 +08:00
parent 1f548548d2
commit 4b650e9f01
2 changed files with 32 additions and 11 deletions

View File

@ -219,9 +219,23 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number S is the source sequence length, N is the batch size, E is the feature number
""" """
src_orig = src
# when warmup == 0.0, alpha is always 0.1, but it gradually changes to
# always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not
# 0.0 is that it gives us a gradient so we can learn something when we are not
# being very useful. The occasional 1.0 will ensure, via self.balancer, that
# the outputs of our modules don't get scaled up too much.
# min(0.1, warmup)
# is used in place of warmup to ensure that even at the start of the warm-up
# period we sometimes use scale 1.0; this ensures that the modules do not
# compensate for the small scale by just producing larger output.
warmup = max(warmup, 0.1)
warmup = min(warmup, 0.95) # effectively, layer-drop.
alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
# macaron style feed forward module # macaron style feed forward module
src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), src = torch.add(src, self.dropout(self.feed_forward_macaron(src)))
alpha=(0.0 if warmup < 0.0 else 1.0))
# multi-headed self-attention module # multi-headed self-attention module
@ -233,19 +247,19 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask, key_padding_mask=src_key_padding_mask,
)[0] )[0]
src = torch.add(src, self.dropout(src_att), src = torch.add(src, self.dropout(src_att))
alpha=(0.0 if warmup < 0.0 else 1.0))
# convolution module # convolution module
src = torch.add(src, self.dropout(self.conv_module(src)), src = torch.add(src, self.dropout(self.conv_module(src)))
alpha=(0.0 if warmup < 0.5 else 1.0))
# feed forward module # feed forward module
src = torch.add(src, self.dropout(self.feed_forward(src)), src = torch.add(src, self.dropout(self.feed_forward(src)))
alpha=(0.0 if warmup < 0.5 else 1.0))
src = self.norm_final(self.balancer(src)) src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1-alpha) * src_orig
return src return src
@ -309,7 +323,7 @@ class ConformerEncoder(nn.Module):
pos_emb, pos_emb,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
warmup=warmup-0.5*(i / num_layers) warmup=warmup,
) )
return output return output

View File

@ -296,7 +296,7 @@ def get_params() -> AttributeDict:
"embedding_dim": 512, "embedding_dim": 512,
# parameters for Noam # parameters for Noam
"warm_step": 60000, # For the 100h subset, use 8k "warm_step": 60000, # For the 100h subset, use 8k
"model_warm_step": 4000, # arg given to model, not for lrate "model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -501,8 +501,15 @@ def compute_loss(
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
warmup=warmup, warmup=warmup,
) )
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (0.0 if warmup < 1.0 else
(0.1 if warmup > 1.0 and warmup < 2.0) else
1.0)
loss = (params.simple_loss_scale * simple_loss + loss = (params.simple_loss_scale * simple_loss +
(pruned_loss * 0.0 if warmup < 1.0 else pruned_loss)) pruned_loss_scale * pruned_loss)
assert loss.requires_grad == is_training assert loss.requires_grad == is_training