Fix potential bug re scale change

commit bc5d72b0f3
parent 57957cc049
Author: Daniel Povey
Date:   2022-06-15 12:49:47 +08:00


@@ -66,12 +66,12 @@ class NeutralGradient(Optimizer):
             params,
             lr=1e-2,
             betas=(0.9, 0.98, 0.98),
-            scale_speed=0.025,
+            scale_speed=0.05,
             eps=1e-8,
             param_eps=1.0e-05,
             cond_eps=1.0e-10,
             param_max=10.0,
-            min_diag_smooth=0.2,
+            min_diag_smooth=0.5,
             max_size=1023,
             stats_period=1,
             estimate_period=200,
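
A minimal usage sketch of how the new defaults would be picked up (the import path here is an assumption, not confirmed by this page; the keyword names come from the signature above):

import torch
from optim import NeutralGradient  # assumed module path

model = torch.nn.Linear(256, 256)
# After this commit the defaults are scale_speed=0.05 and
# min_diag_smooth=0.5; spelled out explicitly here:
optimizer = NeutralGradient(model.parameters(), lr=1e-2,
                            scale_speed=0.05, min_diag_smooth=0.5)
loss = model(torch.randn(8, 256)).pow(2).mean()
loss.backward()
optimizer.step()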
@@ -378,14 +378,17 @@ class NeutralGradient(Optimizer):
         assert step % stats_period == 0
         norm_step = step // stats_period
-        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim):
+        scale_change = True
+        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim,
+                                  scale_change):
             """
             Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
             a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
             of grads for dimension `dim`.
             Returns new ref_exp_avg_sq
             """
-            if dim > 0:
+            if not scale_change:
                 # We only want to "count" the overall scale of the gradients' variance
                 # (i.e. mean variance) once, so for all dims other than 0 we normalize
                 # this out.
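
The point of the new flag: the old code counted the overall gradient-variance scale for dim 0, which breaks when dim 0 is trivial (size 1) and is skipped; the flag instead counts it for the first dimension actually processed. A toy sketch of the pattern in isolation (a stand-in for update_ref_exp_avg_sq, not the repo's actual update rule):

import torch

def fold_in_grad_var(ref_exp_avg_sq, grad_var, scale_change, dim, beta=0.98):
    # Toy stand-in: grad_var has shape (p.shape[dim],).  When scale_change
    # is False the overall magnitude has already been counted, so only the
    # mean-normalized "shape" of the variance is folded in.
    if not scale_change:
        grad_var = grad_var / grad_var.mean()
    shape = [1] * ref_exp_avg_sq.ndim
    shape[dim] = -1
    return beta * ref_exp_avg_sq + (1 - beta) * grad_var.reshape(shape)

p = torch.zeros(4, 5)
ref = torch.ones_like(p)
scale_change = True
for dim in range(p.ndim):
    grad_var = torch.rand(p.shape[dim]) + 0.1  # fake per-dim variance
    ref = fold_in_grad_var(ref, grad_var, scale_change, dim)
    scale_change = False  # the overall scale is counted exactly once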
@@ -412,12 +415,14 @@ class NeutralGradient(Optimizer):
                     grad_var_smoothed = grad_cov + eps*eps
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_var_smoothed,
-                                                       dim)
-                if dim != 0:
-                    # If dim != 0, make the grad_var_smoothed and
-                    # param_var_smoothed have the same mean, because when
-                    # preconditioning gradients, we only want to count the
-                    # overall change in scale once.
+                                                       dim, scale_change)
+                if scale_change:
+                    scale_change = False  # only use the scale change once
+                else:
+                    # For all but the first non-trivial dim, make the
+                    # grad_var_smoothed and param_var_smoothed have the same
+                    # mean, because when preconditioning gradients, we only want
+                    # to count the overall change in scale once.
                     grad_var_smoothed *= (param_var_smoothed.sum() / grad_var_smoothed.sum())
                 proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
             else:
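
The else branch's rescale, checked in isolation: after it, grad_var_smoothed and param_var_smoothed have equal sums (and, being the same length, equal means), so this dim's preconditioner reshapes the gradient without applying a net change of scale. A self-contained check with hypothetical values:

import torch

param_var_smoothed = torch.tensor([4.0, 1.0, 0.25])
grad_var_smoothed = torch.tensor([0.5, 2.0, 8.0])

# Match totals, as in the diff line above:
grad_var_smoothed = grad_var_smoothed * (param_var_smoothed.sum()
                                         / grad_var_smoothed.sum())
assert torch.isclose(grad_var_smoothed.sum(), param_var_smoothed.sum())

proj = (param_var_smoothed / grad_var_smoothed).sqrt()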
@@ -432,9 +437,11 @@ class NeutralGradient(Optimizer):
                                                        min_diag_smooth=min_diag_smooth)
                 ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                        grad_cov_smoothed.diag(),
-                                                       dim)
-                if dim != 0:
+                                                       dim, scale_change)
+                if scale_change:
+                    scale_change = False  # only use the scale change once
+                else:
                     # If dim != 0, make the traces of grad_cov_smoothed and
                     # param_cov_smoothed be the same, because when
                     # preconditioning gradients, we only want to count the
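
The same once-only convention applies in this full-covariance branch, where the comparison is by trace rather than by sum. A toy check of the trace-matching step the comment describes (hypothetical small matrices, not the repo's code):

import torch

g, q = torch.randn(4, 4), torch.randn(4, 4)
grad_cov_smoothed = g @ g.t() + 1e-3 * torch.eye(4)
param_cov_smoothed = q @ q.t() + 1e-3 * torch.eye(4)

# Equalize traces so only the shape of the gradient covariance is used
# for this dim; its overall scale was already counted once.
grad_cov_smoothed = grad_cov_smoothed * (param_cov_smoothed.trace()
                                         / grad_cov_smoothed.trace())
assert torch.isclose(grad_cov_smoothed.trace(), param_cov_smoothed.trace())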