Mirror of https://github.com/k2-fsa/icefall.git

Make warmup work by scaling layer contributions; leave residual layer-drop
commit 4b650e9f01
parent 1f548548d2

@@ -219,9 +219,23 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
+        src_orig = src
+
+        # When warmup == 0.0, alpha is 0.1 most of the time, but it gradually changes
+        # to (almost) always being 1.0 once warmup reaches 1.0.  The reason for using
+        # 0.1 and not 0.0 is that it gives us a gradient so we can learn something
+        # when we are not being very useful.  The occasional 1.0 will ensure, via
+        # self.balancer, that the outputs of our modules don't get scaled up too much.
+
+        # max(warmup, 0.1)
+        # is used in place of warmup to ensure that even at the start of the warm-up
+        # period we sometimes use scale 1.0; this ensures that the modules do not
+        # compensate for the small scale by just producing larger output.
+        warmup = max(warmup, 0.1)
+        warmup = min(warmup, 0.95)  # effectively, layer-drop.
+        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
+
         # macaron style feed forward module
-        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
-                        alpha=(0.0 if warmup < 0.0 else 1.0))
+        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)))

         # multi-headed self-attention module
@@ -233,19 +247,19 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = torch.add(src, self.dropout(src_att),
-                        alpha=(0.0 if warmup < 0.0 else 1.0))
+        src = torch.add(src, self.dropout(src_att))

         # convolution module
-        src = torch.add(src, self.dropout(self.conv_module(src)),
-                        alpha=(0.0 if warmup < 0.5 else 1.0))
+        src = torch.add(src, self.dropout(self.conv_module(src)))

         # feed forward module
-        src = torch.add(src, self.dropout(self.feed_forward(src)),
-                        alpha=(0.0 if warmup < 0.5 else 1.0))
+        src = torch.add(src, self.dropout(self.feed_forward(src)))

         src = self.norm_final(self.balancer(src))

+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
         return src
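
The hunks above amount to a stochastic scaling of each encoder layer's contribution. As a minimal, self-contained sketch of that behaviour (not code from this commit; the helper name and tensor shapes are invented for illustration):

    import torch

    def scale_layer_contribution(layer_out: torch.Tensor,
                                 layer_in: torch.Tensor,
                                 warmup: float) -> torch.Tensor:
        # Clamp warmup: the lower bound 0.1 means alpha == 1.0 still happens
        # occasionally at the very start of training; the upper bound 0.95 means
        # alpha == 0.1 keeps happening with probability >= 0.05 even after
        # warm-up, i.e. a residual layer-drop.
        warmup = min(max(warmup, 0.1), 0.95)
        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
        if alpha != 1.0:
            # Down-weight this layer's contribution and keep most of its input.
            return alpha * layer_out + (1 - alpha) * layer_in
        return layer_out

    x = torch.randn(4, 10, 256)          # (N, S, E) encoder input
    y = x + torch.randn_like(x) * 0.01   # stand-in for the layer's full output
    out = scale_layer_contribution(y, x, warmup=0.2)  # early training: usually 0.1 * y + 0.9 * x
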
@@ -309,7 +323,7 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup-0.5*(i / num_layers)
+                warmup=warmup,
             )

         return output
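
For scale, a tiny illustrative snippet (layer count and warmup value chosen arbitrarily) comparing the warmup each layer received under the old per-layer schedule with the uniform value passed after this hunk:

    num_layers = 12
    warmup = 0.8
    old = [warmup - 0.5 * (i / num_layers) for i in range(num_layers)]  # decays with depth
    new = [warmup] * num_layers                                         # same for every layer
    print(min(old), max(old))  # ~0.342, 0.8 -- deeper layers were warmed up less
    print(min(new), max(new))  # 0.8, 0.8
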
@@ -296,7 +296,7 @@ def get_params() -> AttributeDict:
             "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 60000,  # For the 100h subset, use 8k
-            "model_warm_step": 4000,  # arg given to model, not for lrate
+            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
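
Elsewhere in this recipe the warmup value handed to compute_loss is derived from training progress, roughly warmup = batch_idx_train / model_warm_step (an assumption stated here for illustration; that code is not part of this diff). Under that assumption, a quick sketch of what lowering model_warm_step from 4000 to 3000 changes:

    def warmup_at(batch_idx: int, model_warm_step: int) -> float:
        # hypothetical helper; mirrors the assumed batch_idx / model_warm_step relation
        return batch_idx / model_warm_step

    print(warmup_at(3000, 3000))  # 1.0 -- layers fully contribute, pruned loss starts counting
    print(warmup_at(6000, 3000))  # 2.0 -- pruned loss at full scale (see the next hunk)
    print(warmup_at(6000, 4000))  # 1.5 -- under the old setting, still at the 0.1 scale
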
@@ -501,8 +501,15 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
         )
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (0.0 if warmup < 1.0 else
+                             0.1 if warmup > 1.0 and warmup < 2.0 else
+                             1.0)
         loss = (params.simple_loss_scale * simple_loss +
-                (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss))
+                pruned_loss_scale * pruned_loss)

     assert loss.requires_grad == is_training
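
Restated on its own (same logic as the added lines above, with the conditional expression unrolled into a function): the pruned loss is ignored while warmup < 1.0, weighted by 0.1 while warmup is strictly between 1.0 and 2.0, and used at full weight afterwards.

    def pruned_loss_scale(warmup: float) -> float:
        if warmup < 1.0:
            return 0.0
        if 1.0 < warmup < 2.0:
            return 0.1
        return 1.0

    for w in (0.5, 1.5, 2.5):
        print(w, pruned_loss_scale(w))  # 0.5 0.0 / 1.5 0.1 / 2.5 1.0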