diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8da246c7c..87034b3e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -858,7 +858,7 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - if params.batch_idx_train % 10 == 0: + if int(params.batch_idx_train) % 10 == 0: set_batch_count(model, params.batch_idx_train) scheduler.step_batch(params.batch_idx_train)