From 05baa3bef9200cc91fc5a0b287a4ba050c2e8650 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Sat, 10 Dec 2022 13:54:32 +0900 Subject: [PATCH] from local --- .../.train.py.swp | Bin 90112 -> 90112 bytes .../train.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 302494e07f07036cb640036a8be60cf307f2d37c..9da16018f9291fb8023968f8c1738bef81f13fd0 100644 GIT binary patch delta 508 zcmXBQ-z!657zgm@v7Is7u+=2Ou@P?8mdKA*k{d6VMKCW};Vkgi|a+oJfOnP@xed zcs3E0;24e|2ojVeqBA.MT_IKT=Pc)>qh81YCNLqZFMxY-~+$GQ8uV(X<$6sFj% z<)&CR?DrqoL~)wC#7&s%g-PKCvy7zpZIjjbSU^!`voX6dvRbF{g07TaG7XS1PikE~KMs@~8a<*mrVJR$u$*73=L~NodMxFM3x=f)7EpB z$EhN#L8lRPs)s&%W0lCb>AxYO*N>=B&RQq(VHdeT2F>uJ0k1ZZJ?vrwS;Wx-FWzfK z&Txt~Brpg!o~$B|xI__Kn8XA+;KYkXa&l}DIl>ZJ;l*=}$O-llgAL_skz3rLfOYhv zY^H`C6p+CzLQwc--!I(Z05kZzK^kKw2KY%njc8Q)DmUM<`a(=~|9!m`nr0B=W_>QL qEc*SZ>MCXq-MTNSJo-;kh4pYsH5n@@b*;?p+Wex{W(=g&mHGv;yh%#{ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index 305ebd8bb..3545e1f35 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -1178,7 +1178,7 @@ def run(rank, world_size, args, wb=None): scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - if checkpoints and ("optimizer" in checkpoints) or ("optimizer_enc" in checkpoints): + if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints): if params.multi_optim: logging.info("Loading optimizer state dict") optimizer_enc.load_state_dict(checkpoints["optimizer_enc"]) @@ -1190,11 +1190,15 @@ def run(rank, world_size, args, wb=None): if ( checkpoints - and "scheduler" in checkpoints + and ("scheduler" in checkpoints or "scheduler_enc" in checkpoints) and checkpoints["scheduler"] is not None ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) + if params.multi_optim: + scheduler_enc.load_state_dict(checkpoints["scheduler_enc"]) + scheduler_dec.load_state_dict(checkpoints["scheduler_dec"]) + else: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions(