diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index a1ce21a4c..a736ebcc5 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index 363edca92..519c78e63 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -1528,64 +1528,8 @@ def run_adapter(rank, world_size, args, wb=None):
         clipping_scale=None,
         parameters_names=[adapter_names],
     )
-
-    if params.multi_optim:
-        logging.info("Using seperate optimizers over encoder, decoder ...")
-
-        enc_param = []
-        enc_names = []
-
-        dec_names = []
-        dec_param = []
-
-        for n, p in model.named_parameters():
-            name = n.split('.')[1]
-            if name == 'encoder' and 'feature_extractor' not in n:
-                enc_names.append(n)
-                enc_param.append(p)
-            elif 'ctc_output' in n:
-                enc_names.append(n)
-                enc_param.append(p)
-            elif 'feature_extractor' not in n:
-                dec_names.append(n)
-                dec_param.append(p)
-
-        optimizer_enc = ScaledAdam(
-            enc_param,
-            lr=params.peak_enc_lr,
-            clipping_scale=None,
-            parameters_names=[enc_names],
-        )
-        optimizer_dec = ScaledAdam(
-            dec_param,
-            lr=params.peak_dec_lr,
-            clipping_scale=5.0,
-            parameters_names=[dec_names],
-        )
-
-        scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
-        scheduler_dec = Eden(optimizer_dec, params.lr_batches, params.lr_epochs)
-        optimizer = [optimizer_enc, optimizer_dec]
-        scheduler = [scheduler_enc, scheduler_dec]
-
-    else:
-        parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in model.named_parameters()]
-        )
-
-        logging.info(f"len name = {len(parameters_names)}")
-        logging.info(f"len param = {len(list(model.parameters()))}")
-
-        optimizer = ScaledAdam(
-            model.parameters(),
-            lr=params.base_lr,
-            clipping_scale=2.0,
-            parameters_names=parameters_names,
-        )
-
-        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
-
+    scheduler_enc = Eden(optimizer_enc, params.lr_batches, params.lr_epochs)
+
     if checkpoints and ("optimizer" in checkpoints or "optimizer_enc" in checkpoints):
         if params.multi_optim:
             logging.info("Loading optimizer state dict")