diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b7cd45334..f95d8e73c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,7 +450,9 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = params.simple_loss_scale * simple_loss + if not warmup_mode: + loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) assert loss.requires_grad == is_training