Mirror of https://github.com/k2-fsa/icefall.git
Fix potential bug re scale change
commit bc5d72b0f3
parent 57957cc049
@@ -66,12 +66,12 @@ class NeutralGradient(Optimizer):
         params,
         lr=1e-2,
         betas=(0.9, 0.98, 0.98),
-        scale_speed=0.025,
+        scale_speed=0.05,
         eps=1e-8,
         param_eps=1.0e-05,
         cond_eps=1.0e-10,
         param_max=10.0,
-        min_diag_smooth=0.2,
+        min_diag_smooth=0.5,
         max_size=1023,
         stats_period=1,
         estimate_period=200,
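Note: the two changed defaults in this hunk are scale_speed (0.025 to 0.05) and min_diag_smooth (0.2 to 0.5). A minimal sketch of constructing the optimizer with the new defaults spelled out explicitly (assuming NeutralGradient is importable from the recipe's local optim module; the model is just a placeholder):

    import torch
    from optim import NeutralGradient  # assumed local module in the recipe directory

    model = torch.nn.Linear(256, 256)  # placeholder parameters to optimize
    optimizer = NeutralGradient(
        model.parameters(),
        lr=1e-2,
        betas=(0.9, 0.98, 0.98),
        scale_speed=0.05,       # new default (was 0.025)
        min_diag_smooth=0.5,    # new default (was 0.2)
    )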
@@ -378,14 +378,17 @@ class NeutralGradient(Optimizer):
             assert step % stats_period == 0
             norm_step = step // stats_period
 
-            def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim):
+            scale_change = True
+
+            def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim,
+                                      scale_change):
                 """
                 Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
                 a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
                 of grads for dimension `dim`.
                 Returns new ref_exp_avg_sq
                 """
-                if dim > 0:
+                if not scale_change:
                     # We only want to "count" the overall scale of the gradients' variance
                     # (i.e. mean variance) once, so for all dims other than 0 we normalize
                     # this out.
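The potential bug: the old code decided whether to count the overall scale of the gradients' variance based on `dim > 0`, while the new code threads an explicit scale_change flag through update_ref_exp_avg_sq and consumes it exactly once per parameter tensor. A toy, self-contained sketch of that flag pattern (illustrative only, using a simplified stand-in for update_ref_exp_avg_sq, not the repo's actual update):

    import torch

    def toy_update_ref(ref, grad_var, scale_change):
        # Stand-in for update_ref_exp_avg_sq: once the overall scale has been
        # counted (scale_change is False), normalize the mean out so it is not
        # counted again for later dims.
        if not scale_change:
            grad_var = grad_var / grad_var.mean()
        return 0.9 * ref + 0.1 * grad_var.mean()

    ref = torch.tensor(1.0)
    scale_change = True                  # count the overall scale once per tensor
    for size in (128, 256):              # one iteration per preconditioned dim
        grad_var = torch.rand(size) + 0.5
        ref = toy_update_ref(ref, grad_var, scale_change)
        if scale_change:
            scale_change = False         # only use the scale change once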
@@ -412,12 +415,14 @@ class NeutralGradient(Optimizer):
                         grad_var_smoothed = grad_cov + eps*eps
                         ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                                grad_var_smoothed,
-                                                               dim)
-                        if dim != 0:
-                            # If dim != 0, make the grad_var_smoothed and
-                            # param_var_smoothed have the same mean, because when
-                            # preconditioning gradients, we only want to count the
-                            # overall change in scale once.
+                                                               dim, scale_change)
+                        if scale_change:
+                            scale_change = False  # only use the scale change once
+                        else:
+                            # For all but the first non-trivial dim, make the
+                            # grad_var_smoothed and param_var_smoothed have the same
+                            # mean, because when preconditioning gradients, we only want
+                            # to count the overall change in scale once.
                             grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
                         proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
                     else:
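The else-branch above rescales grad_var_smoothed so that its sum (and, since both vectors have p.shape[dim] entries, its mean) matches that of param_var_smoothed, leaving only relative per-element differences in the projection. A small numeric check with made-up values:

    import torch

    param_var_smoothed = torch.tensor([4.0, 2.0, 2.0])   # sum = 8
    grad_var_smoothed = torch.tensor([1.0, 2.0, 1.0])    # sum = 4

    # Same rescaling as in the else-branch: afterwards both sums are equal,
    # so the overall change of scale is not counted a second time.
    grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
    print(grad_var_smoothed)                              # tensor([2., 4., 2.])

    proj = (param_var_smoothed / grad_var_smoothed).sqrt()
    print(proj)                                           # tensor([1.4142, 0.7071, 1.0000])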
@@ -432,9 +437,11 @@ class NeutralGradient(Optimizer):
                                                             min_diag_smooth=min_diag_smooth)
                         ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                                grad_cov_smoothed.diag(),
-                                                               dim)
+                                                               dim, scale_change)
 
-                        if dim != 0:
+                        if scale_change:
+                            scale_change = False  # only use the scale change once
+                        else:
                             # If dim != 0, make the traces of grad_cov_smoothed and
                             # param_cov_smoothed be the same, because when
                             # preconditioning gradients, we only want to count the