diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 12095810e..704c17dd7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -206,10 +206,8 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            warmup: controls selective activation of layers; if < 0.5, it's possible that
-                not all modules will be included.  Actually we add the
-                feed_forward_macaron and self_attn modules at warmup=0.0
-                and the conv_module and feed_forward at warmup=0.5.
+            warmup: controls selective bypass of layers; if < 1.0, we will
+                bypass layers more frequently.
 
         Shape:
             src: (S, N, E).
@@ -219,19 +217,11 @@
             S is the source sequence length, N is the batch size, E is the feature number
         """
         src_orig = src
-        # when warmup == 0.0, alpha is always 0.1, but it gradually changes to
-        # always being 1.0 when warmup equals 1.0.  The reason for using 0.1 and not
-        # 0.0 is that it gives us a gradient so we can learn something when we are turned
-        # off.
-        #
-        # min(0.1, warmup)
-        # is used in place of warmup to ensure that even at the start of the warm-up
-        # period we sometimes use scale 1.0; this ensures that the modules do not
-        # compensate for the small scale by just producing larger output.
-        warmup = max(warmup, 0.1)
-        if self.training:
-            warmup = min(warmup, 0.95)  # effectively, layer-drop with 1-in-20 prob.
-        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
+
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely
+        # bypass it.
+        alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1
 
         # macaron style feed forward module
         src = src + self.dropout(self.feed_forward_macaron(src))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 47a7169b1..0355c4531 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -62,9 +62,9 @@ class Transducer(nn.Module):
         self.joiner = joiner
 
         self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
-                                           initial_speed=0.25)
+                                           initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
-                                           initial_speed=0.25)
+                                           initial_speed=0.5)
         with torch.no_grad():
             # Initialize the two projections to be the same; this will be
             # convenient for the real joiner, which adds the encoder
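
Note (editor's illustration, not part of the patch): the conformer.py hunk replaces
the old all-or-nothing warmup (alpha is either 1.0 or 0.1) with a ramped
warmup_scale plus a fixed 1-in-10 residual layer-drop. The sketch below is a
hypothetical, self-contained rendering of that scheme; ToyBypassLayer and its single
Linear body stand in for ConformerEncoderLayer and its sub-modules, and the closing
interpolation mirrors the src = alpha * src + (1 - alpha) * src_orig step that
appears later in forward(), outside this hunk.

import torch
import torch.nn as nn


class ToyBypassLayer(nn.Module):
    """Hypothetical stand-in for ConformerEncoderLayer, reduced to the
    warmup/bypass logic introduced by this patch."""

    def __init__(self, d_model: int):
        super().__init__()
        # One Linear stands in for the macaron FF, self-attention,
        # convolution, and FF sub-modules of the real layer.
        self.body = nn.Linear(d_model, d_model)

    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
        src_orig = src
        # warmup ramps from 0.0 to 1.0 over training, so the layer's
        # maximum contribution ramps from 0.1 up to 1.0 with it.
        warmup_scale = min(0.1 + warmup, 1.0)
        # 9 times in 10, use the layer at warmup_scale; 1 time in 10,
        # scale it down to 0.1.  The floor is 0.1 rather than 0.0 so a
        # gradient always flows through the layer's parameters.
        alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1
        src = src + self.body(src)  # real layer: FF/attn/conv/FF + dropout
        if alpha != 1.0:
            # Partially bypass the layer by blending toward its input.
            src = alpha * src + (1.0 - alpha) * src_orig
        return src


layer = ToyBypassLayer(4)
x = torch.randn(3, 2, 4)  # (S, N, E)
print(layer(x, warmup=0.2).shape)  # torch.Size([3, 2, 4])

At warmup=0.0 the layer still contributes at scale 0.1, for the reason given in the
comment the patch deletes: a hard zero would stop gradients. Once warmup reaches
0.9, warmup_scale saturates at 1.0 and the layer is fully used 9 times out of 10,
i.e. a 1-in-10 layer-drop in place of the old 1-in-20.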