mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Confirmed that symmetrizing helps because of the interaction with the regular update; still, meta_lr_scale=0 is best :-(
This commit is contained in:
parent 52bfb2b018
commit a9edecd32c

@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.01,
+        meta_lr_scale=0.0,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
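
The only change above is the default meta_lr_scale dropping from 0.01 to 0.0, matching the commit message's conclusion that meta_lr_scale=0 is still best. A minimal sketch, assuming meta_lr is derived as lr * meta_lr_scale (that derivation is outside this diff):

    lr = 3e-02
    meta_lr_scale = 0.0            # new default
    meta_lr = lr * meta_lr_scale   # assumed relationship; == 0.0 here,
                                   # which zeroes the meta updates below
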
@@ -431,8 +431,14 @@ class LearnedGradient(Optimizer):
             # M_delta = M proj_delta (eq.5)
             proj_grad = state[f"proj_grad_{dim}"]
             Q = state[f"Q_{dim}"]
 
+            # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
+            # as we'll reach equilibrium with M less rapidly.
+            delta_scale = 0.5
+
             if True:
                 # Simpler update, do it to the right of Q.
+                # Symmetrize; otherwise there is a bad interaction with the regular update.
+                proj_grad[:] = proj_grad + proj_grad.t()
                 proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
                 denom = proj_grad_var.sqrt()
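
The new line proj_grad[:] = proj_grad + proj_grad.t() overwrites proj_grad with (twice) its symmetric part, so only the symmetric component of the gradient feeds the update that follows. A toy decomposition on a stand-in tensor, not the optimizer's actual state, showing what survives:

    import torch

    torch.manual_seed(0)
    proj_grad = torch.randn(4, 4)  # stand-in for state[f"proj_grad_{dim}"]

    sym = proj_grad + proj_grad.t()      # symmetric part, scaled by 2
    antisym = proj_grad - proj_grad.t()  # antisymmetric remainder

    assert torch.allclose(sym, sym.t())           # what the update keeps
    assert torch.allclose(antisym, -antisym.t())  # what it discards
    assert torch.allclose(proj_grad, 0.5 * (sym + antisym))
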
@@ -441,9 +447,6 @@ class LearnedGradient(Optimizer):
                 M_delta = torch.matmul(M, proj_delta)
                 M_delta_full = M_delta.reshape(M_full.shape)
                 this_delta = M_delta_full.transpose(dim, -1)
-                # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
-                # as we'll reach equilibrium with M less rapidly.
-                delta_scale = 1.0
                 delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
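
This hunk removes the local delta_scale = 1.0 that used to shadow any earlier setting, so the delta_scale = 0.5 hoisted above the if-block now actually reaches the update. For reference, Tensor.add_(other, alpha=a) is an in-place self += a * other, so the surviving line behaves as sketched here with illustrative values:

    import torch

    torch.manual_seed(0)
    delta, this_delta = torch.zeros(2, 2), torch.ones(2, 2)
    meta_lr, beta1, delta_scale = 3e-4, 0.9, 0.5  # illustrative, not from the diff

    expected = delta - delta_scale * meta_lr * (1 - beta1) * this_delta
    delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
    assert torch.allclose(delta, expected)
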
@@ -463,6 +466,10 @@ class LearnedGradient(Optimizer):
                 # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
                 proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
 
+                # Symmetrize; otherwise there is a bad interaction with the regular
+                # update in self._step().
+                proj2_grad = proj2_grad + proj2_grad.t()
+
                 # for now just estimate proj2_grad_var from proj2_grad; later, we
                 # may try aggregating it across iterations. This is like exp_avg_sq
                 # in Adam.
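
The context line computes eq.(1), proj2_grad = Q^{-T} proj_grad Q^T, from a Q_invt_proj_grad that is presumably formed earlier in the method. A minimal sketch on stand-in tensors (the names, shapes, and solve-based construction are assumptions, not the optimizer's code):

    import torch

    torch.manual_seed(0)
    n = 4
    Q = torch.randn(n, n) + n * torch.eye(n)  # well-conditioned stand-in
    proj_grad = torch.randn(n, n)

    # Solve Q^T X = proj_grad, i.e. X = Q^{-T} proj_grad, without forming Q^{-T}.
    Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
    proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())  # eq.(1)

    # Symmetrize, as the added lines do, before estimating proj2_grad_var.
    proj2_grad = proj2_grad + proj2_grad.t()
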
@@ -494,9 +501,6 @@ class LearnedGradient(Optimizer):
                 # are the final changes, the only 2 we make in this loop that have
                 # side effects.
 
-                # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
-                # as we'll reach equilibrium with M less rapidly.
-                delta_scale = 1.0
                 delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
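
With the new default meta_lr_scale = 0.0 from the first hunk (and, under the assumption sketched there, meta_lr == 0.0), both side-effecting updates above degenerate to no-ops; that is what "still meta_lr_scale=0 best" means in practice. A quick check on stand-in tensors:

    import torch

    torch.manual_seed(0)
    delta, Q = torch.randn(3, 3), torch.eye(3)
    this_delta, Q_delta = torch.randn(3, 3), torch.randn(3, 3)
    meta_lr, beta1, delta_scale = 0.0, 0.9, 0.5

    delta_before, Q_before = delta.clone(), Q.clone()
    delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
    Q.add_(Q_delta, alpha=-meta_lr)
    assert torch.equal(delta, delta_before) and torch.equal(Q, Q_before)
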