minor fix of zipformer/optim.py (#1474)

This commit is contained in:
Zengwei Yao 2024-01-26 15:50:11 +08:00 committed by GitHub
parent 9c494a3329
commit c401a2646b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -298,11 +298,14 @@ class ScaledAdam(BatchedOptimizer):
# case 2 or case 4 # case 2 or case 4
# the input is groups of parameter or named parameter. # the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups: for cur_group in iterable_or_groups:
assert "named_params" in cur_group if "named_params" in cur_group:
name_list = [x[0] for x in cur_group["named_params"]] name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]] p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"] del cur_group["named_params"]
cur_group["params"] = p_list cur_group["params"] = p_list
else:
assert "params" in cur_group
name_list = ["foo" for _ in cur_group["params"]]
param_groups.append(cur_group) param_groups.append(cur_group)
param_groups_names.append(name_list) param_groups_names.append(name_list)