diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 2afb5b7b4..dc73f9efb 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index c6f410bd2..51a67e637 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -958,12 +958,22 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
+            if params.multi_optim and batch_idx % params.accum_grads == 0:
+                set_batch_count(model, params.batch_idx_train)
+                scheduler_enc.step_batch(params.batch_idx_train)
+                scheduler_dec.step_batch(params.batch_idx_train)
+                scaler.step(optimizer_enc)
+                scaler.step(optimizer_dec)
+                scaler.update()
+                optimizer_enc.zero_grad()
+                optimizer_dec.zero_grad()
+            elif not params.multi_optim and batch_idx % params.accum_grads == 0:
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
 
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params, sp=sp)
             raise
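
For context, the sketch below isolates the pattern the new `params.multi_optim` branch implements: separate encoder and decoder optimizers stepped together under torch.cuda.amp.GradScaler, with gradients accumulated over `accum_grads` micro-batches. The toy modules, random data, and learning rates are placeholders rather than icefall code, and the set_batch_count / scheduler.step_batch calls are omitted; only the scaler and optimizer call order mirrors the hunk above.

    # Illustrative sketch only: toy modules and random data stand in for the
    # data2vec encoder and transducer decoder trained in train.py.
    import torch

    enc = torch.nn.Linear(16, 32)   # placeholder for the encoder
    dec = torch.nn.Linear(32, 4)    # placeholder for the decoder/joiner

    optimizer_enc = torch.optim.AdamW(enc.parameters(), lr=1e-4)
    optimizer_dec = torch.optim.AdamW(dec.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    accum_grads = 4  # step the optimizers once every accum_grads micro-batches

    for batch_idx in range(16):
        x = torch.randn(8, 16)
        y = torch.randint(0, 4, (8,))
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            loss = torch.nn.functional.cross_entropy(dec(enc(x)), y)

        # backward() runs every micro-batch, so gradients accumulate until
        # the next optimizer step.
        scaler.scale(loss).backward()

        if batch_idx % accum_grads == 0:
            # Step both optimizers, then call scaler.update() once,
            # matching the order in the diff above.
            scaler.step(optimizer_enc)
            scaler.step(optimizer_dec)
            scaler.update()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

Calling scaler.step() once per optimizer and scaler.update() once per stepping iteration is PyTorch's documented multi-optimizer GradScaler usage, so the call order introduced in the hunk is consistent with that pattern.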