diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 714d8db9a..aaffbfed5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -298,11 +298,14 @@ class ScaledAdam(BatchedOptimizer): # case 2 or case 4 # the input is groups of parameter or named parameter. for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list + if "named_params" in cur_group: + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + else: + assert "params" in cur_group + name_list = ["foo" for _ in cur_group["params"]] param_groups.append(cur_group) param_groups_names.append(name_list)