From 90dac69bc5a83ee282b76f085284c3bd40ece9dd Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Tue, 10 Oct 2023 16:51:07 +0800
Subject: [PATCH] fix optim.py

---
 .../ASR/zipformer_prompt_asr/optim.py | 23 +++++++------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
index 4d99983f6..a767761eb 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py
@@ -298,19 +298,13 @@ class ScaledAdam(BatchedOptimizer):
             # case 2 or case 4
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
-                if "params" in cur_group:
-                    for p in cur_group["params"]:
-                        assert isinstance(p, torch.Tensor)
-                    param_groups.append(cur_group)
-                    param_groups_names.append(["foo"] * len(p_list))
-                else:
-                    assert "named_params" in cur_group
-                    name_list = [ x[0] for x in cur_group["named_params"] ]
-                    p_list = [ x[1] for x in cur_group["named_params"] ]
-                    del cur_group["named_params"]
-                    cur_group["params"] = p_list
-                    param_groups.append(cur_group)
-                    param_groups_names.append(name_list)
+                assert "named_params" in cur_group
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
+                del cur_group["named_params"]
+                cur_group["params"] = p_list
+                param_groups.append(cur_group)
+                param_groups_names.append(name_list)
 
         return param_groups, param_groups_names
 
@@ -674,8 +668,7 @@ class ScaledAdam(BatchedOptimizer):
             # We have to look at the trained model for parameters at or around the
             # param_max_rms, because sometimes they can indicate a problem with the
             # topology or settings.
-            scale_step = torch.minimum(scale_step,
-                                       (param_max_rms - param_rms) / param_rms)
+            scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
 
             delta = state["delta"]
             # the factor of (1-beta1) relates to momentum.
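
Note on the first hunk: the deleted "params" branch was broken, since it referenced p_list before that variable was ever assigned (p_list was only defined in the else branch), so any group supplying a plain "params" list would raise a NameError. After this patch, every group passed to ScaledAdam in this recipe must carry a "named_params" entry holding (name, tensor) pairs. A minimal sketch of the resulting call pattern is below; the toy model and the hyperparameter values are placeholders, not the recipe's real settings:

    import torch

    from optim import ScaledAdam  # egs/libriheavy/ASR/zipformer_prompt_asr/optim.py

    # Stand-in module; the real recipe passes the Zipformer model's parameters.
    model = torch.nn.Linear(10, 10)

    # Each group must use "named_params" with (name, tensor) pairs; the optimizer
    # splits them back into cur_group["params"] and a parallel list of names.
    param_groups = [{"named_params": list(model.named_parameters())}]

    optimizer = ScaledAdam(param_groups, lr=0.045, clipping_scale=2.0)

Keeping the names alongside the tensors lets ScaledAdam report per-parameter diagnostics (e.g. which parameter hit param_max_rms) by name rather than by the "foo" placeholders the removed branch produced.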