Conformed that symmetrizing helps because of interaction with regular update; still meta_lr_scale=0 best :-(

This commit is contained in:
Daniel Povey 2022-07-09 05:20:04 +08:00
parent 52bfb2b018
commit a9edecd32c

View File

@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
params,
lr=3e-02,
size_lr_scale=0.1,
meta_lr_scale=0.01,
meta_lr_scale=0.0,
betas=(0.9, 0.98),
eps=1.0e-08,
size_update_period=1,
@ -431,8 +431,14 @@ class LearnedGradient(Optimizer):
# M_delta = M proj_delta (eq.5)
proj_grad = state[f"proj_grad_{dim}"]
Q = state[f"Q_{dim}"]
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
# as we'll reach equilbrium with M less rapidly.
delta_scale=0.5
if True:
# Simpler update, do it to the right of Q.
# symmetrize, otherwise there is a bad interaction with the regular update.
proj_grad[:] = proj_grad + proj_grad.t()
proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
denom = proj_grad_var.sqrt()
@ -441,9 +447,6 @@ class LearnedGradient(Optimizer):
M_delta = torch.matmul(M, proj_delta)
M_delta_full = M_delta.reshape(M_full.shape)
this_delta = M_delta_full.transpose(dim, -1)
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
# as we'll reach equilbrium with M less rapidly.
delta_scale=1.0
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)
@ -463,6 +466,10 @@ class LearnedGradient(Optimizer):
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
# symmetrize, otherwise there is a bad interaction with the regular
# update in self._step().
proj2_grad = proj2_grad + proj2_grad.t()
# for now just estimate proj2_grad_var from proj2_grad; later, we
# may try aggregating it across iterations. This is like exp_avg_sq
# in Adam.
@ -494,9 +501,6 @@ class LearnedGradient(Optimizer):
# are the final changes, the only 2 we make in this loop that have
# side effects.
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
# as we'll reach equilbrium with M less rapidly.
delta_scale=1.0
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)