mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
fix optim.py
This commit is contained in:
parent
9ef5dfe380
commit
90dac69bc5
@@ -298,19 +298,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# case 2 or case 4
|
||||
# the input is groups of parameter or named parameter.
|
||||
for cur_group in iterable_or_groups:
|
||||
if "params" in cur_group:
|
||||
for p in cur_group["params"]:
|
||||
assert isinstance(p, torch.Tensor)
|
||||
param_groups.append(cur_group)
|
||||
param_groups_names.append(["foo"] * len(p_list))
|
||||
else:
|
||||
assert "named_params" in cur_group
|
||||
name_list = [ x[0] for x in cur_group["named_params"] ]
|
||||
p_list = [ x[1] for x in cur_group["named_params"] ]
|
||||
del cur_group["named_params"]
|
||||
cur_group["params"] = p_list
|
||||
param_groups.append(cur_group)
|
||||
param_groups_names.append(name_list)
|
||||
assert "named_params" in cur_group
|
||||
name_list = [x[0] for x in cur_group["named_params"]]
|
||||
p_list = [x[1] for x in cur_group["named_params"]]
|
||||
del cur_group["named_params"]
|
||||
cur_group["params"] = p_list
|
||||
param_groups.append(cur_group)
|
||||
param_groups_names.append(name_list)
|
||||
|
||||
return param_groups, param_groups_names
|
||||
|
||||
@@ -674,8 +668,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# We have to look at the trained model for parameters at or around the
|
||||
# param_max_rms, because sometimes they can indicate a problem with the
|
||||
# topology or settings.
|
||||
scale_step = torch.minimum(scale_step,
|
||||
(param_max_rms - param_rms) / param_rms)
|
||||
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
|
||||
|
||||
delta = state["delta"]
|
||||
# the factor of (1-beta1) relates to momentum.
|
||||
|
Loading…
x
Reference in New Issue
Block a user