diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ae45db60f..851822aae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,9 +450,8 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss - if not warmup_mode: - loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss) + loss = (params.simple_loss_scale * simple_loss + + (pruned_loss * 0.01 if warmup_mode else pruned_loss)) assert loss.requires_grad == is_training @@ -687,18 +686,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 20 seconds return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) valid_cuts = librispeech.dev_clean_cuts()