Merge branch 'rework2i_restoredrop_scaled_warmup' into rework2i_restoredrop_scaled_warmup_2proj

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless2/model.py
Daniel Povey 2022-03-31 14:45:55 +08:00
commit 49bc761ba1
2 changed files with 9 additions and 19 deletions


@@ -206,10 +206,8 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective activation of layers; if < 0.5, it's possible that
-              not all modules will be included. Actually we add the
-              feed_forward_macaron and self_attn modules at warmup=0.0
-              and the conv_module and feed_forward at warmup=0.5.
+            warmup: controls selective bypass of of layers; if < 1.0, we will
+              bypass layers more frequently.

         Shape:
             src: (S, N, E).
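
The scale implied by the new docstring is computed in the next hunk as min(0.1 + warmup, 1.0). A quick standalone check of the values this yields (plain Python, not code from the commit):

    for warmup in (0.0, 0.25, 0.5, 0.9, 1.0):
        warmup_scale = min(0.1 + warmup, 1.0)
        print(f"warmup={warmup} -> warmup_scale={warmup_scale}")
    # warmup=0.0 -> warmup_scale=0.1
    # warmup=0.25 -> warmup_scale=0.35
    # warmup=0.5 -> warmup_scale=0.6
    # warmup=0.9 -> warmup_scale=1.0
    # warmup=1.0 -> warmup_scale=1.0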
@@ -219,19 +217,11 @@ class ConformerEncoderLayer(nn.Module):
         S is the source sequence length, N is the batch size, E is the feature number
         """
         src_orig = src
-        # when warmup == 0.0, alpha is always 0.1, but it gradually changes to
-        # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not
-        # 0.0 is that it gives us a gradient so we can learn something when we are turned
-        # off.
-        #
-        # min(0.1, warmup)
-        # is used in place of warmup to ensure that even at the start of the warm-up
-        # period we sometimes use scale 1.0; this ensures that the modules do not
-        # compensate for the small scale by just producing larger output.
-        warmup = max(warmup, 0.1)
-        if self.training:
-            warmup = min(warmup, 0.95)  # effectively, layer-drop with 1-in-20 prob.
-        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely
+        # bypass it.
+        alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale

         # macaron style feed forward module
         src = src + self.dropout(self.feed_forward_macaron(src))
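
In the old code, warmup was clamped into [0.1, 0.95] during training and alpha was 1.0 with probability warmup (otherwise 0.1); the new code draws alpha between 0.1 and warmup_scale instead. The step that actually applies alpha lies outside this hunk, so the following is only a minimal sketch, assuming (from the src_orig = src line and the "bypass" comment) that the layer ends by interpolating its output with the saved input; apply_layer_with_bypass and layer_fn are hypothetical names:

    import torch

    def apply_layer_with_bypass(layer_fn, src: torch.Tensor, alpha: float) -> torch.Tensor:
        # Sketch only: run the layer, then blend its output with the original
        # input. alpha = 1.0 keeps the full layer output; a small alpha such
        # as 0.1 mostly bypasses the layer while leaving a gradient path so
        # the bypassed modules can still learn.
        src_orig = src
        src = layer_fn(src)
        if alpha != 1.0:
            src = alpha * src + (1.0 - alpha) * src_orig
        return src

    # Toy usage with a stand-in "layer":
    layer_fn = lambda x: x + 0.01 * torch.randn_like(x)
    out = apply_layer_with_bypass(layer_fn, torch.randn(10, 2, 8), alpha=0.1)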


@@ -62,9 +62,9 @@ class Transducer(nn.Module):
         self.joiner = joiner

         self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
-                                           initial_speed=0.25)
+                                           initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
-                                           initial_speed=0.25)
+                                           initial_speed=0.5)
         with torch.no_grad():
             # Initialize the two projections to be the same; this will be
             # convenient for the real joiner, which adds the endcoder
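
The context kept by this hunk initializes the two projections to identical parameters inside torch.no_grad(). A minimal standalone sketch of that pattern, using nn.Linear as a stand-in for the recipe's ScaledLinear and placeholder sizes (neither is taken from the diff):

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 500  # placeholder sizes

    # nn.Linear stands in for ScaledLinear here.
    simple_am_proj = nn.Linear(embedding_dim, vocab_size)
    simple_lm_proj = nn.Linear(embedding_dim, vocab_size)

    with torch.no_grad():
        # Copy parameters so both projections start out identical, mirroring
        # the torch.no_grad() block shown above.
        simple_lm_proj.weight.copy_(simple_am_proj.weight)
        simple_lm_proj.bias.copy_(simple_am_proj.bias)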