diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 1aa9d1fe3..c4b6677dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -666,10 +666,14 @@ def compute_loss( s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) + pruned_loss_scale = ( + 1.0 if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) loss = ( - simple_loss_scale * simple_loss - + pruned_loss + simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) assert loss.requires_grad == is_training