This works better for reasons I don't understand. Transpose is enough, same as symmetrizing.
This commit is contained in:
parent e9ab1ddd39
commit 52bfb2b018
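The commit message's claim that adding the transpose is the same as symmetrizing can be checked directly: proj_grad + proj_grad.t() is exactly twice the symmetrized matrix, and the factor of 2 cancels in any normalization that scales linearly with the matrix. A minimal sketch, with a plain RMS over entries standing in for the _get_matrix_var(...).sqrt() denominator used in the diff below (an assumption; the real helper may weight entries differently):

    import torch

    torch.manual_seed(0)
    G = torch.randn(4, 4)  # stands in for proj_grad

    G_t = G + G.t()            # the update in this commit: add the transpose
    G_sym = 0.5 * (G + G.t())  # full symmetrization

    # Any normalization proportional to the matrix's scale (here a plain
    # RMS over entries) absorbs the factor of 2, so the normalized
    # updates coincide.
    delta_t = G_t / (G_t ** 2).mean().sqrt()
    delta_sym = G_sym / (G_sym ** 2).mean().sqrt()
    assert torch.allclose(delta_t, delta_sym)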
@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.05,
+        meta_lr_scale=0.01,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
@@ -379,9 +379,10 @@ class LearnedGradient(Optimizer):
         # proj being numerically equal to I but treated as a variable, i.e. as
         # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
         # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
-        # where proj_grad == M_underlying^T M_grad == M^T M_grad
+        # where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad
         #
-        # Now, proj_grad is convenient to compute, being invariant of Q, but it's not in a very normalized
+        # Now, proj_grad is convenient to compute, being independent of Q,
+        # but it's not in a very normalized
         # space; we want to normalize it before doing gradient descent on it.
         # We're going to view M as
         # M == N Q,
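A quick numerical check of the trace identity in the comments above, with random matrices standing in for M and M_grad and proj taken to be I (so M == M_underlying), names following the comment:

    import torch

    torch.manual_seed(0)
    n = 5
    M = torch.randn(n, n)       # at proj == I, M == M_underlying
    M_grad = torch.randn(n, n)  # gradient of the loss w.r.t. M
    proj = torch.eye(n)

    proj_grad = M.t() @ M_grad  # == M_underlying^T M_grad, as in the comment

    lhs = torch.trace(M_grad.t() @ M)        # tr(M_grad^T M)
    rhs = torch.trace(proj_grad.t() @ proj)  # tr(proj_grad^T proj)
    assert torch.allclose(lhs, rhs)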
@@ -430,6 +431,25 @@ class LearnedGradient(Optimizer):
             # M_delta = M proj_delta (eq.5)
+            proj_grad = state[f"proj_grad_{dim}"]
+            Q = state[f"Q_{dim}"]
+            if True:
+                # Simpler update, do it to the right of Q.
+                proj_grad[:] = proj_grad + proj_grad.t()
+                proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
+                denom = proj_grad_var.sqrt()
+                proj_delta = proj_grad / denom
+                Q_delta = torch.matmul(Q, proj_delta)
+                M_delta = torch.matmul(M, proj_delta)
+                M_delta_full = M_delta.reshape(M_full.shape)
+                this_delta = M_delta_full.transpose(dim, -1)
+                # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
+                # as we'll reach equilibrium with M less rapidly.
+                delta_scale = 1.0
+                delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
+                # There is no momentum on Q.
+                Q.add_(Q_delta, alpha=-meta_lr)
+                proj_grad.zero_()
+                continue
 
 
             # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
             # torch.linalg.solve(A, B) returns A^{-1} B.
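For the (eq.1) step referenced in the trailing context lines, a minimal sketch of how Q^{-T} proj_grad Q^T could be formed without explicitly inverting Q, using torch.linalg.solve (the well-conditioned random Q here is only a stand-in for the optimizer's state):

    import torch

    torch.manual_seed(0)
    n = 4
    Q = torch.randn(n, n) + 4.0 * torch.eye(n)  # well-conditioned stand-in
    proj_grad = torch.randn(n, n)

    # proj2_grad = Q^{-T} proj_grad Q^T, with no explicit inverse:
    # torch.linalg.solve(A, B) returns A^{-1} B, so pass A = Q^T.
    proj2_grad = torch.linalg.solve(Q.t(), proj_grad @ Q.t())

    # Agrees with the explicit-inverse form.
    ref = torch.linalg.inv(Q.t()) @ proj_grad @ Q.t()
    assert torch.allclose(proj2_grad, ref, atol=1e-4)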