This works better for reasons I dont understand. transpose is enough, same as symmetrizing.

This commit is contained in:
Daniel Povey 2022-07-08 11:53:59 +08:00
parent e9ab1ddd39
commit 52bfb2b018

View File

@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
params,
lr=3e-02,
size_lr_scale=0.1,
meta_lr_scale=0.05,
meta_lr_scale=0.01,
betas=(0.9, 0.98),
eps=1.0e-08,
size_update_period=1,
@ -379,9 +379,10 @@ class LearnedGradient(Optimizer):
# proj being numerically equal to I but treated as a variable, i.e. as
# M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
# it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
# where proj_grad == M_underlying^T M_grad == M^T M_grad
# where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad
#
# Now, proj_grad is convenient to compute, being invariant of Q, but it's not in a very normalized
# Now, proj_grad is convenient to compute, being independent of Q,
# but it's not in a very normalized
# space; we want to normalize it before doing gradient descent on it.
# We're going to view M as
# M == N Q,
@ -430,6 +431,25 @@ class LearnedGradient(Optimizer):
# M_delta = M proj_delta (eq.5)
proj_grad = state[f"proj_grad_{dim}"]
Q = state[f"Q_{dim}"]
if True:
# Simpler update, do it to the right of Q.
proj_grad[:] = proj_grad + proj_grad.t()
proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
denom = proj_grad_var.sqrt()
proj_delta = proj_grad / denom
Q_delta = torch.matmul(Q, proj_delta)
M_delta = torch.matmul(M, proj_delta)
M_delta_full = M_delta.reshape(M_full.shape)
this_delta = M_delta_full.transpose(dim, -1)
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
# as we'll reach equilbrium with M less rapidly.
delta_scale=1.0
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)
proj_grad.zero_()
continue
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
# torch.solve(A, B) returns A^{-1} B.