fix optim.py

marcoyang 2023-10-10 16:51:07 +08:00
parent 9ef5dfe380
commit 90dac69bc5


@@ -298,19 +298,13 @@ class ScaledAdam(BatchedOptimizer):
         # case 2 or case 4
         # the input is groups of parameter or named parameter.
         for cur_group in iterable_or_groups:
-            if "params" in cur_group:
-                for p in cur_group["params"]:
-                    assert isinstance(p, torch.Tensor)
-                param_groups.append(cur_group)
-                param_groups_names.append(["foo"] * len(p_list))
-            else:
-                assert "named_params" in cur_group
-                name_list = [x[0] for x in cur_group["named_params"]]
-                p_list = [x[1] for x in cur_group["named_params"]]
-                del cur_group["named_params"]
-                cur_group["params"] = p_list
-                param_groups.append(cur_group)
-                param_groups_names.append(name_list)
+            assert "named_params" in cur_group
+            name_list = [x[0] for x in cur_group["named_params"]]
+            p_list = [x[1] for x in cur_group["named_params"]]
+            del cur_group["named_params"]
+            cur_group["params"] = p_list
+            param_groups.append(cur_group)
+            param_groups_names.append(name_list)
         return param_groups, param_groups_names
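
After this change, every group passed through this path must carry a `named_params` entry holding `(name, tensor)` pairs; the former branch that accepted a bare `params` list and filled the name lists with `"foo"` placeholders is gone. The snippet below is a minimal, standalone sketch (not part of optim.py; `split_named_groups` and the `lr` value are illustrative only) showing what the surviving loop produces for a toy module:

```python
import torch


def split_named_groups(groups):
    # Mirrors the surviving loop above: each group's (name, tensor) pairs are
    # split into a "params" list for the optimizer plus a parallel list of names.
    param_groups, param_groups_names = [], []
    for cur_group in groups:
        assert "named_params" in cur_group
        name_list = [x[0] for x in cur_group["named_params"]]
        p_list = [x[1] for x in cur_group["named_params"]]
        del cur_group["named_params"]
        cur_group["params"] = p_list
        param_groups.append(cur_group)
        param_groups_names.append(name_list)
    return param_groups, param_groups_names


model = torch.nn.Linear(4, 2)
groups, names = split_named_groups(
    [{"named_params": list(model.named_parameters()), "lr": 0.045}]
)
print(names)            # [['weight', 'bias']]
print(list(groups[0]))  # ['lr', 'params'] -- "named_params" has been consumed
```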
@@ -674,8 +668,7 @@ class ScaledAdam(BatchedOptimizer):
         # We have to look at the trained model for parameters at or around the
         # param_max_rms, because sometimes they can indicate a problem with the
         # topology or settings.
-        scale_step = torch.minimum(scale_step,
-                                   (param_max_rms - param_rms) / param_rms)
+        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
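
The second hunk is purely cosmetic: the `torch.minimum` call is joined onto a single line, with no change in behavior. As a hedged aside (toy numbers, not taken from optim.py), the snippet below shows why this cap matters if `scale_step` is read as a fractional change of the parameter scale: clamping it at `(param_max_rms - param_rms) / param_rms` keeps the resulting RMS at or below `param_max_rms`.

```python
import torch

param_rms = torch.tensor([0.5, 1.9, 2.5])  # toy per-parameter RMS values
param_max_rms = torch.tensor(2.0)          # toy upper limit
scale_step = torch.full((3,), 0.1)         # proposed fractional scale change

# Same clamp as in the diff: steps that would push the RMS past the limit are
# reduced (and can go negative, actively shrinking an over-large parameter).
capped = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(capped)                    # tensor([ 0.1000,  0.0526, -0.2000])
print(param_rms * (1 + capped))  # tensor([0.5500, 2.0000, 2.0000])
```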