mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Confirmed that symmetrizing helps because of interaction with regular update; still meta_lr_scale=0 best :-(
This commit is contained in:
parent
52bfb2b018
commit
a9edecd32c
@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
|
|||||||
params,
|
params,
|
||||||
lr=3e-02,
|
lr=3e-02,
|
||||||
size_lr_scale=0.1,
|
size_lr_scale=0.1,
|
||||||
meta_lr_scale=0.01,
|
meta_lr_scale=0.0,
|
||||||
betas=(0.9, 0.98),
|
betas=(0.9, 0.98),
|
||||||
eps=1.0e-08,
|
eps=1.0e-08,
|
||||||
size_update_period=1,
|
size_update_period=1,
|
||||||
@ -431,8 +431,14 @@ class LearnedGradient(Optimizer):
|
|||||||
# M_delta = M proj_delta (eq.5)
|
# M_delta = M proj_delta (eq.5)
|
||||||
proj_grad = state[f"proj_grad_{dim}"]
|
proj_grad = state[f"proj_grad_{dim}"]
|
||||||
Q = state[f"Q_{dim}"]
|
Q = state[f"Q_{dim}"]
|
||||||
|
|
||||||
|
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
|
||||||
|
# as we'll reach equilibrium with M less rapidly.
|
||||||
|
delta_scale=0.5
|
||||||
|
|
||||||
if True:
|
if True:
|
||||||
# Simpler update, do it to the right of Q.
|
# Simpler update, do it to the right of Q.
|
||||||
|
# symmetrize, otherwise there is a bad interaction with the regular update.
|
||||||
proj_grad[:] = proj_grad + proj_grad.t()
|
proj_grad[:] = proj_grad + proj_grad.t()
|
||||||
proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
|
proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
|
||||||
denom = proj_grad_var.sqrt()
|
denom = proj_grad_var.sqrt()
|
||||||
@ -441,9 +447,6 @@ class LearnedGradient(Optimizer):
|
|||||||
M_delta = torch.matmul(M, proj_delta)
|
M_delta = torch.matmul(M, proj_delta)
|
||||||
M_delta_full = M_delta.reshape(M_full.shape)
|
M_delta_full = M_delta.reshape(M_full.shape)
|
||||||
this_delta = M_delta_full.transpose(dim, -1)
|
this_delta = M_delta_full.transpose(dim, -1)
|
||||||
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
|
|
||||||
# as we'll reach equilibrium with M less rapidly.
|
|
||||||
delta_scale=1.0
|
|
||||||
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
|
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
|
||||||
# there is no momentum on Q.
|
# there is no momentum on Q.
|
||||||
Q.add_(Q_delta, alpha=-meta_lr)
|
Q.add_(Q_delta, alpha=-meta_lr)
|
||||||
@ -463,6 +466,10 @@ class LearnedGradient(Optimizer):
|
|||||||
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
||||||
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
|
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
|
||||||
|
|
||||||
|
# symmetrize, otherwise there is a bad interaction with the regular
|
||||||
|
# update in self._step().
|
||||||
|
proj2_grad = proj2_grad + proj2_grad.t()
|
||||||
|
|
||||||
# for now just estimate proj2_grad_var from proj2_grad; later, we
|
# for now just estimate proj2_grad_var from proj2_grad; later, we
|
||||||
# may try aggregating it across iterations. This is like exp_avg_sq
|
# may try aggregating it across iterations. This is like exp_avg_sq
|
||||||
# in Adam.
|
# in Adam.
|
||||||
@ -494,9 +501,6 @@ class LearnedGradient(Optimizer):
|
|||||||
# are the final changes, the only 2 we make in this loop that have
|
# are the final changes, the only 2 we make in this loop that have
|
||||||
# side effects.
|
# side effects.
|
||||||
|
|
||||||
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
|
|
||||||
# as we'll reach equilibrium with M less rapidly.
|
|
||||||
delta_scale=1.0
|
|
||||||
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
|
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
|
||||||
# there is no momentum on Q.
|
# there is no momentum on Q.
|
||||||
Q.add_(Q_delta, alpha=-meta_lr)
|
Q.add_(Q_delta, alpha=-meta_lr)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user