diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 0e13dd20f..d4c1709e9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -854,6 +854,9 @@ def train_one_epoch(
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx % 10 == 0:
+            set_batch_count(model, get_adjusted_batch_count(params))
+
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
@@ -876,8 +879,7 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            if int(params.batch_idx_train) % 10 == 1:
-                set_batch_count(model, get_adjusted_batch_count(params))
+
             scheduler.step_batch(params.batch_idx_train)
 
             scaler.step(optimizer)