From b566f3b9254bab734ca8117cfb48f1a95c8aad9f Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sat, 10 Dec 2022 14:01:43 +0900
Subject: [PATCH] from local

---
 .../.train.py.swp | Bin 98304 -> 98304 bytes
 .../train.py      | 20 +++++++++++++-----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 2afb5b7b4e1ddc96b2d00baf36b823b9825701b5..dc73f9efb89b8c8a326b446ba009e998cd49596b 100644
GIT binary patch
delta 590
[base85 payload of the Vim swap-file delta omitted: binary data, not human-readable]

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index c6f410bd2..51a67e637 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -958,12 +958,22 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
+            if params.multi_optim and batch_idx % params.accum_grads == 0:
+                set_batch_count(model, params.batch_idx_train)
+                scheduler_enc.step_batch(params.batch_idx_train)
+                scheduler_dec.step_batch(params.batch_idx_train)
+                scaler.step(optimizer_enc)
+                scaler.step(optimizer_dec)
+                scaler.update()
+                optimizer_enc.zero_grad()
+                optimizer_dec.zero_grad()
+            elif not params.multi_optim and batch_idx % params.accum_grads == 0:
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
 
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params, sp=sp)
             raise
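
For readers skimming the hunk above: the change replaces the single unconditional
optimizer step with one that fires only on gradient-accumulation boundaries, and
that drives either one optimizer or a separate encoder/decoder pair depending on
params.multi_optim. The sketch below is a minimal, self-contained reproduction of
that update pattern, not the icefall code itself: the Linear modules, learning
rates, and accum_grads value are made-up placeholders, while the ordering of
scaler.scale(...).backward(), scaler.step(...), and scaler.update() mirrors the
patch. PyTorch's GradScaler explicitly supports stepping several optimizers with
one update() per iteration, which is the pattern the multi_optim branch relies on.

    # Minimal sketch of the patch's update pattern: AMP gradient scaling plus
    # gradient accumulation over two optimizers (encoder/decoder). All names
    # and hyperparameters here are illustrative stand-ins, not icefall's.
    import torch

    enc = torch.nn.Linear(16, 16)  # stand-in for the data2vec encoder
    dec = torch.nn.Linear(16, 4)   # stand-in for the transducer decoder/joiner

    optimizer_enc = torch.optim.AdamW(enc.parameters(), lr=1e-4)
    optimizer_dec = torch.optim.AdamW(dec.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    accum_grads = 4  # step the optimizers every 4 micro-batches (assumed value)

    for batch_idx in range(1, 17):
        x = torch.randn(8, 16)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            loss = dec(enc(x)).sum()  # reduction == sum, as the hunk's NOTE says

        # Backward runs on every micro-batch; gradients accumulate in .grad.
        scaler.scale(loss).backward()

        # Optimizers (and the scaler) step only on accumulation boundaries,
        # matching `batch_idx % params.accum_grads == 0` in the patch.
        if batch_idx % accum_grads == 0:
            scaler.step(optimizer_enc)
            scaler.step(optimizer_dec)
            scaler.update()  # update the loss scale once per accumulation step
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

One consequence of this design worth noting: because the loss uses
reduction == sum and is not divided by accum_grads anywhere in the hunk, an
accumulation step behaves like a single step on a batch accum_grads times
larger with a summed loss, consistent with the hunk's NOTE that no
normalization is applied so far.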