mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
fix optim.py
This commit is contained in:
parent
9ef5dfe380
commit
90dac69bc5
@ -298,19 +298,13 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# case 2 or case 4
|
# case 2 or case 4
|
||||||
# the input is groups of parameter or named parameter.
|
# the input is groups of parameter or named parameter.
|
||||||
for cur_group in iterable_or_groups:
|
for cur_group in iterable_or_groups:
|
||||||
if "params" in cur_group:
|
assert "named_params" in cur_group
|
||||||
for p in cur_group["params"]:
|
name_list = [x[0] for x in cur_group["named_params"]]
|
||||||
assert isinstance(p, torch.Tensor)
|
p_list = [x[1] for x in cur_group["named_params"]]
|
||||||
param_groups.append(cur_group)
|
del cur_group["named_params"]
|
||||||
param_groups_names.append(["foo"] * len(p_list))
|
cur_group["params"] = p_list
|
||||||
else:
|
param_groups.append(cur_group)
|
||||||
assert "named_params" in cur_group
|
param_groups_names.append(name_list)
|
||||||
name_list = [ x[0] for x in cur_group["named_params"] ]
|
|
||||||
p_list = [ x[1] for x in cur_group["named_params"] ]
|
|
||||||
del cur_group["named_params"]
|
|
||||||
cur_group["params"] = p_list
|
|
||||||
param_groups.append(cur_group)
|
|
||||||
param_groups_names.append(name_list)
|
|
||||||
|
|
||||||
return param_groups, param_groups_names
|
return param_groups, param_groups_names
|
||||||
|
|
||||||
@ -674,8 +668,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# We have to look at the trained model for parameters at or around the
|
# We have to look at the trained model for parameters at or around the
|
||||||
# param_max_rms, because sometimes they can indicate a problem with the
|
# param_max_rms, because sometimes they can indicate a problem with the
|
||||||
# topology or settings.
|
# topology or settings.
|
||||||
scale_step = torch.minimum(scale_step,
|
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
|
||||||
(param_max_rms - param_rms) / param_rms)
|
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
# the factor of (1-beta1) relates to momentum.
|
# the factor of (1-beta1) relates to momentum.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user