From a56fb98700de889d2a023f58ae52aa256fbd9eab Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Sat, 10 Dec 2022 14:59:48 +0900
Subject: [PATCH] from local

---
 .../.train.py.swp | Bin 110592 -> 110592 bytes
 .../train.py      | 32 +++++++++++++-----
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index dcac3d7802ea32f2ed48afbc3794b06a4007997f..663a77250c8d9cb53fc773cd910749c2e30d8799 100644
GIT binary patch
delta 448
[base85-encoded binary delta of the Vim swap file omitted]
delta 430
[base85-encoded binary delta of the Vim swap file omitted]

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index c87de67cd..ff9d887c1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -1037,16 +1037,30 @@ def train_one_epoch(
                 )
 
         if batch_idx % params.log_interval == 0:
-            cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            if params.multi_optim:
+                cur_enc_lr = scheduler.get_last_lr()[0]
+                cur_dec_lr = scheduler.get_last_lr()[0]
+                cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
 
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
-            )
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                    f"enc_lr: {cur_enc_lr:.2e}, dec_lr: {cur_dec_lr:.2e}, "
+                    + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                )
+
+            else:
+                cur_lr = scheduler.get_last_lr()[0]
+                cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                    f"lr: {cur_lr:.2e}, "
+                    + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                )
 
         if tb_writer is not None:
            tb_writer.add_scalar(
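
Note on the hunk above: the two branches differ only in which learning-rate fields they print, so the duplicated logging.info call could be collapsed into a single call that varies just the lr portion of the message. The sketch below is illustrative and not part of the patch; the helper name format_lr_fields is hypothetical, and it assumes the encoder and decoder learning rates are already available as plain floats (as cur_enc_lr and cur_dec_lr are in the multi_optim branch).

# Hypothetical helper, not in the patch: builds only the lr portion of the
# per-batch log line so one logging.info call can serve both branches.
def format_lr_fields(enc_lr: float, dec_lr: float, multi_optim: bool) -> str:
    """Return the learning-rate fields for the training log line."""
    if multi_optim:
        # Separate encoder/decoder optimizers: report both learning rates.
        return f"enc_lr: {enc_lr:.2e}, dec_lr: {dec_lr:.2e}, "
    # Single optimizer: keep the original single-lr format.
    return f"lr: {enc_lr:.2e}, "


# Example usage with made-up values:
print(format_lr_fields(1.0e-3, 5.0e-4, multi_optim=True))
# -> enc_lr: 1.00e-03, dec_lr: 5.00e-04,
print(format_lr_fields(1.0e-3, 5.0e-4, multi_optim=False))
# -> lr: 1.00e-03,

Keeping the rest of the log message identical across both code paths makes the per-batch lines easier to grep and avoids the two near-identical logging.info blocks shown in the diff.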