minor fix of zipformer/optim.py

This commit is contained in:
yaozengwei 2024-01-26 15:39:58 +08:00
parent 231bbcd2b6
commit 951b537e40

View File

@ -298,11 +298,14 @@ class ScaledAdam(BatchedOptimizer):
# case 2 or case 4
# the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups:
assert "named_params" in cur_group
name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list
if "named_params" in cur_group:
name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list
else:
assert "params" in cur_group
name_list = ["foo" for _ in cur_group["params"]]
param_groups.append(cur_group)
param_groups_names.append(name_list)