Couple configuration changes, comment simplification

Daniel Povey 2022-07-08 09:46:42 +08:00
parent 75e872ea57
commit be6680e3ba


@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.0,
+        meta_lr_scale=0.05,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
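For context, this default change enables the meta-learning update out of the box. A minimal usage sketch; the import path and `model` are assumptions, while the keyword names come from the signature above:

    import torch
    from optim import LearnedGradient  # assumed module path

    model = torch.nn.Linear(256, 256)  # hypothetical model
    # meta_lr_scale now defaults to 0.05 rather than 0.0, so the learned
    # "learning-rate matrix" update is active without extra configuration.
    optimizer = LearnedGradient(model.parameters(), lr=3e-02)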
@@ -189,9 +189,9 @@ class LearnedGradient(Optimizer):
                 state[f"Q_{dim}"] = torch.eye(size, **kwargs)
                 # proj_grad_{dim} is the gradient w.r.t. `proj`,
-                # which is a matrix we introduce to the right of p, i.e.
-                # instead of M == torch.matmul(N, p) we view it as
-                # M == torch.matmul(N, torch.matmul(p, proj))
+                # which is a matrix we introduce to the right of Q, i.e.
+                # instead of M == torch.matmul(N, Q) we view it as
+                # M == torch.matmul(N, torch.matmul(Q, proj))
                 # or equivalently M == torch.matmul(M_underlying, proj)
                 # where M_underlying is numerically equal to M, and proj is numerically
                 # equal to I, but they are viewed as separate variables.
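The identity this comment relies on can be checked numerically: with proj numerically equal to I, M_underlying equals M, and the gradient of tr(M_grad^T M_underlying proj) w.r.t. proj is M^T M_grad. A self-contained sketch, all names illustrative:

    import torch

    size = 4
    M = torch.randn(size, size)
    M_grad = torch.randn(size, size)  # stand-in for the gradient w.r.t. M

    proj = torch.eye(size, requires_grad=True)  # numerically I, treated as a variable
    M_underlying = M.clone()                    # numerically equal to M

    # pseudo-loss tr(M_grad^T M_underlying proj); equals tr(M_grad^T M) at proj == I
    loss = (M_grad * (M_underlying @ proj)).sum()
    loss.backward()

    # autograd agrees with the closed form proj_grad == M^T M_grad
    assert torch.allclose(proj.grad, M.t() @ M_grad, atol=1e-5)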
@@ -377,19 +377,19 @@ class LearnedGradient(Optimizer):
         # which will be of shape (size, size). For purposes of our update, if we view
         # the parameter matrix M as (M_underlying proj) with
         # proj being numerically equal to I but treated as a variable, i.e. as
-        # matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
+        # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
         # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
-        # where proj_grad == M_underlying^T M_grad.
+        # where proj_grad == M_underlying^T M_grad == M^T M_grad
         #
-        # Now, proj_grad is convenient to compute but it's not in a very normalized
-        # space, we want to normalize it before doing gradient descent on it.
+        # Now, proj_grad is convenient to compute, being invariant of Q, but it's not in a very normalized
+        # space; we want to normalize it before doing gradient descent on it.
         # We're going to view M as
         #   M == N Q,
         # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
         # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
-        # It's more convenient to do gradient descent on a `proj` matrix that's between N and Q,
-        # i.e. define `proj2`, also numerically equal to I but treated as a variable,
-        # and have:
+        # It's likely better numerically to do gradient descent on a `proj2` matrix
+        # that's between N and Q,
+        # i.e.:
         #   M == N proj2 Q
         # We want to obtain the derivative w.r.t. proj2.
         # So the linearized pseudo-loss is
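The derivative being set up here can likewise be verified: for the pseudo-loss tr(M_grad^T N proj2 Q), the gradient w.r.t. proj2 is N^T M_grad Q^T. A sketch under the same illustrative naming:

    import torch

    size = 4
    M_grad = torch.randn(size, size)
    N = torch.randn(size, size)
    Q = torch.randn(size, size)

    proj2 = torch.eye(size, requires_grad=True)  # numerically I, treated as a variable
    loss = (M_grad * (N @ proj2 @ Q)).sum()      # tr(M_grad^T N proj2 Q)
    loss.backward()

    # closed form: proj2_grad == N^T M_grad Q^T
    assert torch.allclose(proj2.grad, N.t() @ M_grad @ Q.t(), atol=1e-5)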
@@ -412,20 +412,13 @@ class LearnedGradient(Optimizer):
         #   Q += proj2_delta Q   (eq.2)
         #
         # We also need to update M itself; this is easiest done by computing proj_delta
-        # (which is the difference proj from I). The relationship between proj and proj2
+        # (which is the difference of proj from I). The relationship between proj and proj2
         # is given by equating
         #   M == N proj2 Q == N Q proj,
-        # so proj2 Q == Q proj
-        #    proj == Q^{-1} proj2 Q
+        # so  proj2 Q == Q proj
+        #     proj == Q^{-1} proj2 Q
         # and subtracting out the constant "I" part,
         #   proj_delta == Q^{-1} proj2_delta Q   (eq.3)
-        # ... so the change in Q is given by:
-        #   Q := Q (I + proj_delta),
-        #   Q += Q proj_delta
-        #   Q_delta = Q proj_delta
-        # and looking at how we compute proj_delta (eq.3), we notice the right-hand subexpression
-        # is the same as p_delta, i.e.:
-        #   Q_delta = proj2_delta Q   (eq.4)
         #
         # and then because we view the parameter matrix as
         # "M proj",
@@ -459,13 +452,11 @@ class LearnedGradient(Optimizer):
                 # we'll take into account the factor of -meta_lr * (1-beta1) later on;
                 # actually proj2_delta is the negative of a parameter change.
-                proj2_delta = proj2_grad_var / denom
-                # See (eq.4), Q_delta = proj2_delta q
+                proj2_delta = proj2_grad / denom
+                # See (eq.2), Q += proj2_delta Q
                 Q_delta = torch.matmul(proj2_delta, Q)
-                assert not torch.all(proj2_delta == 0)
                 # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
                 try:
                     proj_delta = torch.linalg.solve(Q, Q_delta)
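On the torch.linalg.solve call: solving Q x = Q_delta yields x == Q^{-1} Q_delta without materializing the inverse, which is cheaper and numerically safer; it raises an error when Q is singular, which is what the surrounding try/except guards against. A minimal illustration:

    import torch

    size = 4
    Q = torch.randn(size, size) + 2.0 * torch.eye(size)  # invertible by construction
    Q_delta = torch.randn(size, size)

    proj_delta = torch.linalg.solve(Q, Q_delta)  # preferred form
    # equivalent, but forms the explicit inverse:
    assert torch.allclose(proj_delta, torch.linalg.inv(Q) @ Q_delta, atol=1e-4)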
@@ -483,8 +474,8 @@ class LearnedGradient(Optimizer):
                 # are the final changes, the only 2 we make in this loop that have
                 # side effects.
-                # delta_scale will make it update the learning rates faster than it otherwise would.
-                delta_scale=0.2
+                # delta_scale < 1 will make it update the learning rates faster than it otherwise would.
+                delta_scale=1.0
                 delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
@@ -586,7 +577,7 @@ class LearnedGradient(Optimizer):
             Q[:] = torch.matmul(U * S_new, V.t())
             if random.random() < 0.03:
                 subsample = max(1, S.numel() // 20)
-                logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
+                logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")
         if True:
             # This block does the actual diagonalization.
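The reconstruction pattern in this hunk is the standard one: torch.svd returns U, S, V with Q == U diag(S) V^T, and (U * S_new) @ V.t() scales the columns of U by S_new, i.e. it equals U @ diag(S_new) @ V^T. A standalone sketch with a placeholder rule for S_new, since the file's actual modification of S lies outside this hunk:

    import logging
    import random

    import torch

    Q = torch.randn(8, 8)
    U, S, V = torch.svd(Q)        # Q == U @ torch.diag(S) @ V.t()
    S_new = S.clamp(min=0.1)      # placeholder; not the file's actual rule
    Q[:] = torch.matmul(U * S_new, V.t())  # U * S_new == U @ torch.diag(S_new)

    if random.random() < 0.03:    # occasional, subsampled logging as in the hunk
        subsample = max(1, S.numel() // 20)
        logging.info(f"modified S from {S[::subsample]} to {S_new[::subsample]}")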