From a1b59b8792535d4728580c65392957057376a7b7 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Mon, 26 Dec 2022 13:58:38 +0900
Subject: [PATCH] from local

---
 .../.train.py.swp | Bin 94208 -> 94208 bytes
 .../train.py      | 60 +-----------------
 2 files changed, 2 insertions(+), 58 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index a1ce21a4c02dc4f8316ed0280000a7df75c41ebd..a736ebcc59df3a1f3c88b6ef48e834f7f1e6a177 100644
GIT binary patch
delta 770
[base85-encoded binary delta omitted]

delta 375
[base85-encoded binary delta omitted]

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index 363edca92..519c78e63 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -1528,64 +1528,8 @@ def run_adapter(rank, world_size, args, wb=None):
         clipping_scale=None,
         parameters_names=[adapter_names],
     )
-
-    if params.multi_optim:
-        logging.info("Using seperate optimizers over encoder, decoder ...")
-
-        enc_param = []
-        enc_names = []
-
-        dec_names = []
-        dec_param = []
-
-        for n, p in model.named_parameters():
-            name = n.split('.')[1]
-            if name == 'encoder' and 'feature_extractor' not in n:
-                enc_names.append(n)
-                enc_param.append(p)
-            elif 'ctc_output' in n:
-                enc_names.append(n)
-                enc_param.append(p)
-            elif 'feature_extractor' not in n:
-                dec_names.append(n)
-                dec_param.append(p)
-
-        optimizer_enc = ScaledAdam(
-            enc_param,
-            lr=params.peak_enc_lr,
-            clipping_scale=None,
-            parameters_names=[enc_names],
-        )
-        optimizer_dec = ScaledAdam(
-            dec_param,
-            lr=params.peak_dec_lr,
-            clipping_scale=5.0,
-            parameters_names=[dec_names],
-        )
-
-        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
-        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
-        optimizer = [optimizer_enc, optimizer_dec]
-        scheduler = [scheduler_enc, scheduler_dec]
-
-    else:
-        parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in model.named_parameters()]
-        )
-
-        logging.info(f"len name = {len(parameters_names)}")
-        logging.info(f"len param = {len(list(model.parameters()))}")
-
-        optimizer = ScaledAdam(
-            model.parameters(),
-            lr=params.base_lr,
-            clipping_scale=2.0,
-            parameters_names=parameters_names,
-        )
-
-        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
-
+    scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
+
     if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
         if params.multi_optim:
             logging.info("Loading optimizer state dict")
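
Note for review: the hunk above deletes the multi_optim branch that built separate encoder and decoder optimizers, while the one surviving added line still references optimizer_enc and the checkpoint-loading code below still tests params.multi_optim. As a reference while reviewing, here is a minimal sketch of the dual-optimizer setup that the hunk removes. It assumes icefall's recipe-local optim module (providing ScaledAdam and Eden) and the model and params objects already in scope inside run_adapter(); it is an illustration only, not part of the patch.

    # Illustrative sketch only -- mirrors the removed multi_optim branch.
    # Assumes `from optim import Eden, ScaledAdam` (icefall recipe-local module)
    # and that `model` and `params` are the objects used in run_adapter().
    from optim import Eden, ScaledAdam

    enc_param, enc_names = [], []
    dec_param, dec_names = [], []

    for n, p in model.named_parameters():
        # Encoder parameters (minus the feature extractor) and the CTC head go
        # to the encoder optimizer; the remaining decoder/joiner parameters go
        # to the decoder optimizer.
        if n.split('.')[1] == 'encoder' and 'feature_extractor' not in n:
            enc_names.append(n)
            enc_param.append(p)
        elif 'ctc_output' in n:
            enc_names.append(n)
            enc_param.append(p)
        elif 'feature_extractor' not in n:
            dec_names.append(n)
            dec_param.append(p)

    optimizer_enc = ScaledAdam(
        enc_param,
        lr=params.peak_enc_lr,
        clipping_scale=None,
        parameters_names=[enc_names],
    )
    optimizer_dec = ScaledAdam(
        dec_param,
        lr=params.peak_dec_lr,
        clipping_scale=5.0,
        parameters_names=[dec_names],
    )

    scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
    scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
    optimizer = [optimizer_enc, optimizer_dec]
    scheduler = [scheduler_enc, scheduler_dec]

The point of the two ScaledAdam instances is that the encoder (peak_enc_lr, no gradient clipping) and the decoder/joiner (peak_dec_lr, clipping_scale=5.0) can follow different learning-rate peaks while sharing the same Eden schedule settings; downstream code then treats optimizer and scheduler as lists when it steps or checkpoints them.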