diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index a81777353..64030ef90 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -219,9 +219,23 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
+        src_orig = src
+        # When warmup == 0.0, alpha is almost always 0.1, but it gradually changes
+        # to almost always being 1.0 when warmup reaches 1.0.  The reason for using
+        # 0.1 rather than 0.0 is that it keeps a gradient path through the layer, so
+        # it can still learn something even while mostly bypassed.  The occasional
+        # 1.0 will ensure, via self.balancer, that the outputs of our modules don't
+        # get scaled up too much.
+        #
+        # max(warmup, 0.1) is used in place of warmup to ensure that even at the
+        # start of the warm-up period we sometimes use scale 1.0; this ensures that
+        # the modules do not compensate for the small scale by just producing larger
+        # output.
+        warmup = max(warmup, 0.1)
+        warmup = min(warmup, 0.95)  # effectively, layer-drop.
+        alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1
+
         # macaron style feed forward module
-        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
-                        alpha=(0.0 if warmup < 0.0 else 1.0))
+        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)))
 
         # multi-headed self-attention module
@@ -233,19 +247,19 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = torch.add(src, self.dropout(src_att),
-                        alpha=(0.0 if warmup < 0.0 else 1.0))
+        src = torch.add(src, self.dropout(src_att))
 
         # convolution module
-        src = torch.add(src, self.dropout(self.conv_module(src)),
-                        alpha=(0.0 if warmup < 0.5 else 1.0))
+        src = torch.add(src, self.dropout(self.conv_module(src)))
 
         # feed forward module
-        src = torch.add(src, self.dropout(self.feed_forward(src)),
-                        alpha=(0.0 if warmup < 0.5 else 1.0))
+        src = torch.add(src, self.dropout(self.feed_forward(src)))
 
         src = self.norm_final(self.balancer(src))
 
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
         return src
@@ -309,7 +323,7 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup-0.5*(i / num_layers)
+                warmup=warmup,
             )
 
         return output
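Note: for reference, the per-layer behaviour introduced in conformer.py above can be summarised by the standalone sketch below. The function name blend_with_residual and its free-function form are illustrative only and not part of the patch; in the patch the blending is applied once at the end of forward(), using the block input src_orig.

import torch
from torch import Tensor


def blend_with_residual(layer_out: Tensor, layer_in: Tensor, warmup: float) -> Tensor:
    # Clamp warmup so the probability of keeping the layer at full scale is
    # at least 0.1 (even at the very start of training) and at most 0.95
    # (so the layer is still occasionally down-weighted later on, which
    # acts like layer-drop).
    p_full = min(max(warmup, 0.1), 0.95)
    alpha = 1.0 if torch.rand(()).item() <= p_full else 0.1
    # alpha == 1.0 keeps the layer output as-is; alpha == 0.1 mostly bypasses
    # the layer while leaving a small gradient path through it.
    return layer_out if alpha == 1.0 else alpha * layer_out + (1.0 - alpha) * layer_in
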
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 01cf289f5..35991f5e9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -296,7 +296,7 @@ def get_params() -> AttributeDict:
             "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 60000,  # For the 100h subset, use 8k
-            "model_warm_step": 4000,  # arg given to model, not for lrate
+            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -501,8 +501,15 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
         )
+        # After the main warm-up period, we keep pruned_loss_scale small
+        # for the same number of steps (model_warm_step) again, to avoid
+        # overwhelming the simple_loss and causing it to diverge, in case
+        # it had not fully learned the alignment yet.
+        pruned_loss_scale = (0.0 if warmup < 1.0 else
+                             (0.1 if warmup > 1.0 and warmup < 2.0 else
+                              1.0))
         loss = (params.simple_loss_scale * simple_loss +
-                (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss))
+                pruned_loss_scale * pruned_loss)
 
         assert loss.requires_grad == is_training
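Note: if warmup is computed by the caller as batch_idx / model_warm_step (an assumption; the call site is not shown in this diff), the new loss schedule amounts to the small helper below. pruned_loss_scale_for is a hypothetical name used only for illustration.

def pruned_loss_scale_for(warmup: float) -> float:
    # warmup crosses 1.0 at the end of the model warm-up period and 2.0
    # one warm-up period later.
    if warmup < 1.0:
        return 0.0  # train on simple_loss only while warming up
    if 1.0 < warmup < 2.0:
        return 0.1  # keep pruned_loss down-weighted for one more period
    return 1.0      # full weight afterwards

Under that assumption, with model_warm_step lowered from 4000 to 3000, the pruned loss starts contributing (at weight 0.1) after roughly 3000 batches and at full weight after roughly 6000.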