Increase lr_est_period (note: the diff below changes the default lr_est_period from 5 to 2, i.e. it lowers the value)

This commit is contained in:
Daniel Povey 2022-07-08 05:51:18 +08:00
parent fb36712e6b
commit ceb9815f2b

View File

@@ -68,7 +68,7 @@ class LearnedGradient(Optimizer):
param_max_rms=2.0,
lr_mat_min=0.01,
lr_mat_max=4.0,
lr_est_period=5,
lr_est_period=2,
diagonalize_period=4,
):
@@ -251,15 +251,13 @@ class LearnedGradient(Optimizer):
# parameter rms. Updates delta.
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
else:
if step % lr_est_period == 0:
if step % lr_est_period == 0 and step > 0:
self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0:
if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
logging.info("Diagonalizing")
self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq(state)
if True:
self._step(group, p, grad, state)
self._step(group, p, grad, state)
p.add_(delta)
state["step"] = step + 1
@@ -574,17 +572,25 @@ class LearnedGradient(Optimizer):
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write:
# N_grad_cov = Q M_grad^T M_grad Q^T
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize) as:
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger when the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd()
logging.info(f"Diagonalizing, dispersion changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
# N_grad_cov is SPD (symmetric positive definite), so its SVD coincides
# with its eigendecomposition (V == U), i.e.
# N_grad_cov = U S U^T.
@@ -605,7 +611,7 @@ class LearnedGradient(Optimizer):
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q = U^T Q (eq.10)
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
@@ -759,7 +765,8 @@ class LearnedGradient(Optimizer):
if bias_correction2 < 0.99:
# note: not in-place.
# scalar_exp_avg_sq is also zeroed at step "zero_step", so
# use the same bias_correction2.
scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.