Fix potential bug re scale change

Daniel Povey 2022-06-15 12:49:47 +08:00
parent 57957cc049
commit bc5d72b0f3


@@ -66,12 +66,12 @@ class NeutralGradient(Optimizer):
         params,
         lr=1e-2,
         betas=(0.9, 0.98, 0.98),
-        scale_speed=0.025,
+        scale_speed=0.05,
         eps=1e-8,
         param_eps=1.0e-05,
         cond_eps=1.0e-10,
         param_max=10.0,
-        min_diag_smooth=0.2,
+        min_diag_smooth=0.5,
         max_size=1023,
         stats_period=1,
         estimate_period=200,
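
This first hunk only raises two defaults: scale_speed (0.025 -> 0.05) and min_diag_smooth (0.2 -> 0.5). A minimal usage sketch, assuming only the constructor arguments visible in this hunk; the import path and the model are hypothetical, since this view does not show which file defines the class:

    import torch.nn as nn
    # Hypothetical import path; adjust to wherever NeutralGradient
    # actually lives in the repo.
    from optim import NeutralGradient

    model = nn.Linear(100, 10)
    optimizer = NeutralGradient(
        model.parameters(),
        lr=1e-2,
        betas=(0.9, 0.98, 0.98),
        scale_speed=0.05,      # new default; was 0.025 before this commit
        min_diag_smooth=0.5,   # new default; was 0.2 before this commit
    )
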
@@ -378,14 +378,17 @@ class NeutralGradient(Optimizer):
         assert step % stats_period == 0
         norm_step = step // stats_period
-        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim):
+        scale_change = True
+        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim,
+                                  scale_change):
             """
             Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
             a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
             of grads for dimension `dim`.
             Returns new ref_exp_avg_sq
             """
-            if dim > 0:
+            if not scale_change:
                 # We only want to "count" the overall scale of the gradients' variance
                 # (i.e. mean variance) once, so for all dims other than 0 we normalize
                 # this out.
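
The new scale_change flag generalizes the old `dim > 0` test: instead of assuming dim 0 always carries the overall variance scale, whichever dim is processed first while the flag is still set carries it, and every later dim is normalized to mean 1 so the scale is counted exactly once. A standalone sketch of that pattern (hypothetical names, not the class's actual code):

    import torch

    def normalize_later_dims(grad_vars):
        # grad_vars: per-dim smoothed variance vectors for one parameter.
        # The first vector keeps the overall scale; later ones are
        # normalized to mean 1 so the scale is counted exactly once.
        out = []
        scale_change = True
        for v in grad_vars:
            if scale_change:
                scale_change = False  # only use the scale change once
                out.append(v)
            else:
                out.append(v / v.mean())
        return out

    first, second = normalize_later_dims([torch.rand(4) + 0.1,
                                          torch.rand(8) + 0.1])
    print(second.mean())  # ~= 1.0: only `first` carries the overall scale
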
@@ -412,12 +415,14 @@ class NeutralGradient(Optimizer):
                 grad_var_smoothed = grad_cov + eps*eps
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_var_smoothed,
-                                                       dim)
-                if dim != 0:
-                    # If dim != 0, make the grad_var_smoothed and
-                    # param_var_smoothed have the same mean, because when
-                    # preconditioning gradients, we only want to count the
-                    # overall change in scale once.
+                                                       dim, scale_change)
+                if scale_change:
+                    scale_change = False  # only use the scale change once
+                else:
+                    # For all but the first non-trivial dim, make the
+                    # grad_var_smoothed and param_var_smoothed have the same
+                    # mean, because when preconditioning gradients, we only want
+                    # to count the overall change in scale once.
                     grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
                 proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
             else:
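
To see why the rescaling line changes shape but not scale: multiplying grad_var_smoothed by param_var_smoothed.sum() / grad_var_smoothed.sum() gives the two vectors equal sums (and equal means, since both have shape (p.shape[dim],)), so the resulting proj redistributes variance across coordinates without changing the parameter's overall scale. A hedged numeric sketch with made-up values:

    import torch

    param_var_smoothed = torch.tensor([1.0, 2.0, 3.0])
    grad_var_smoothed = torch.tensor([0.5, 0.1, 0.4])

    grad_var_smoothed = grad_var_smoothed * (param_var_smoothed.sum() /
                                             grad_var_smoothed.sum())
    proj = (param_var_smoothed / grad_var_smoothed).sqrt()

    print(grad_var_smoothed.sum())  # 6.0, same as param_var_smoothed.sum()
    print(proj)                     # reshapes variance, overall scale unchanged
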
@@ -432,9 +437,11 @@ class NeutralGradient(Optimizer):
                                                          min_diag_smooth=min_diag_smooth)
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_cov_smoothed.diag(),
-                                                       dim)
-                if dim != 0:
+                                                       dim, scale_change)
+                if scale_change:
+                    scale_change = False  # only use the scale change once
+                else:
                     # If dim != 0, make the traces of grad_cov_smoothed and
                     # param_cov_smoothed be the same, because when
                     # preconditioning gradients, we only want to count the
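
The full-covariance branch applies the same once-only scale rule by matching traces rather than means. A hedged numeric sketch with made-up tensors (not the optimizer's real state): scaling grad_cov_smoothed by trace(param_cov_smoothed) / trace(grad_cov_smoothed) equalizes the traces, so this branch, like the diagonal one, counts the overall change in scale only once.

    import torch

    param_cov_smoothed = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
    grad_cov_smoothed = torch.tensor([[0.5, 0.1, 0.0],
                                      [0.1, 0.3, 0.0],
                                      [0.0, 0.0, 0.2]])

    grad_cov_smoothed = grad_cov_smoothed * (param_cov_smoothed.trace() /
                                             grad_cov_smoothed.trace())
    print(grad_cov_smoothed.trace())  # 6.0, matching param_cov_smoothed.trace()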