Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Increase lr_est_period
commit ceb9815f2b (parent fb36712e6b)
@@ -68,7 +68,7 @@ class LearnedGradient(Optimizer):
         param_max_rms=2.0,
         lr_mat_min=0.01,
         lr_mat_max=4.0,
-        lr_est_period=5,
+        lr_est_period=2,
         diagonalize_period=4,
     ):
 
@@ -251,14 +251,12 @@ class LearnedGradient(Optimizer):
                     # parameter rms. Updates delta.
                     self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                 else:
-                    if step % lr_est_period == 0:
+                    if step % lr_est_period == 0 and step > 0:
                         self._update_lrs(group, p, state)
-                    if step % (lr_est_period * diagonalize_period) == 0:
+                    if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
                         logging.info("Diagonalizing")
                         self._diagonalize_lrs(group, p, state)
                         self._zero_exp_avg_sq(state)
-                    if True:
                     self._step(group, p, grad, state)
                 p.add_(delta)
                 state["step"] = step + 1
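With this change, and taking the defaults above (lr_est_period=2, diagonalize_period=4), the learning-rate matrices are re-estimated every 2 steps, a diagonalization plus reset of the squared-gradient stats happens every 8 steps, and the "and step > 0" guard skips both on step 0. A minimal standalone sketch of that schedule (the action names are placeholders, not icefall APIs):

lr_est_period = 2
diagonalize_period = 4

for step in range(10):
    actions = []
    if step % lr_est_period == 0 and step > 0:
        actions.append("update_lrs")        # every lr_est_period steps
    if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
        actions.append("diagonalize_lrs")   # every lr_est_period * diagonalize_period steps
    actions.append("step")                  # the regular parameter update, every step
    print(step, actions)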
@@ -574,17 +572,25 @@ class LearnedGradient(Optimizer):
         # Now,
         #    grad_cov == M_grad^T M_grad,
         # decaying-averaged over minibatches; this is of shape (size,size).
-        # Using N_grad = M_grad Q^T, we can write:
-        #    N_grad_cov = Q M_grad^T M_grad Q^T
+        # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
+        # (which is what we want to diagonalize), as::
+        #    N_grad_cov = N_grad^T N_grad
+        #               = Q M_grad^T M_grad Q^T
         #               = Q grad_cov Q^T
         # (note: this makes sense because the 1st index of Q is the diagonalized
         # index).
 
+        def dispersion(v):
+            # a ratio that, if the eigenvalues are more dispersed, it will be larger.
+            return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
+
         grad_cov = state[f"grad_cov_{dim}"]
         N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
         N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
         U, S, V = N_grad_cov.svd()
 
+        logging.info(f"Diagonalizing, dispersion changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
+
         # N_grad_cov is SPD, so
         #    N_grad_cov = U S U^T.
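The dispersion measure logged here is mean(v^2) / mean(v)^2: it equals 1.0 when all values are equal and grows as they spread out. For an SPD matrix the eigenvalues have the same mean as the diagonal entries but at least as large a mean square, so the logged value can only go up after diagonalization. A rough standalone sketch of that measurement, using a random SPD grad_cov and a random orthogonal Q in place of the optimizer's internal state:

import torch

def dispersion(v):
    # mean(v^2) / mean(v)^2: 1.0 when all entries are equal, larger when spread out.
    return ((v ** 2).mean() / (v.mean() ** 2)).item()

size = 8
A = torch.randn(size, size)
grad_cov = A @ A.t()                              # stand-in for the averaged M_grad^T M_grad
Q, _ = torch.linalg.qr(torch.randn(size, size))   # stand-in orthogonal Q

N_grad_cov = Q @ grad_cov @ Q.t()                 # covariance of N_grad = M_grad Q^T
N_grad_cov = 0.5 * (N_grad_cov + N_grad_cov.t())  # keep it exactly symmetric
U, S, _ = torch.svd(N_grad_cov)                   # SPD, so N_grad_cov = U S U^T

# The eigenvalues S are at least as dispersed as the current diagonal:
print(dispersion(N_grad_cov.diag()), dispersion(S))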
@@ -605,7 +611,7 @@ class LearnedGradient(Optimizer):
         # we modify hat_N, to keep the product M unchanged:
         #    M = N Q = N U U^T Q = hat_N hat_Q
         # This can be done by setting
-        #    hat_Q = U^T Q    (eq.10)
+        #    hat_Q := U^T Q   (eq.10)
         #
         # This is the only thing we have to do, as N is implicit
         # and not materialized at this point.
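A quick numerical check of the identity behind eq.10 (a standalone sketch, not optimizer code): for any orthogonal U, replacing N by hat_N = N U and Q by hat_Q = U^T Q leaves M = N Q unchanged, because U U^T = I.

import torch

size = 6
N = torch.randn(size, size)
Q = torch.randn(size, size)
U, _ = torch.linalg.qr(torch.randn(size, size))  # any orthogonal matrix

M = N @ Q
hat_N = N @ U       # N is implicit in the optimizer; here we form it just for the check
hat_Q = U.t() @ Q   # eq.10: hat_Q := U^T Q

print(torch.allclose(M, hat_N @ hat_Q, atol=1e-4))  # True: M = N U U^T Q = N Q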
@@ -759,7 +765,8 @@ class LearnedGradient(Optimizer):
 
 
         if bias_correction2 < 0.99:
-            # note: not in-place.
+            # scalar_exp_avg_sq is also zeroed at step "zero_step", so
+            # use the same bias_correction2.
            scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
 
         denom = scalar_exp_avg_sq.sqrt() + eps  # scalar.
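For reference, a hedged sketch of the correction being applied here, with made-up numbers (beta2 is chosen small so both branches are exercised, and zero_step=0 is hypothetical): after the squared-gradient average is zeroed, bias_correction2 = 1 - beta2^(steps since zeroing) undercounts the true second moment, so dividing by it compensates; once it reaches 0.99 the correction is skipped as negligible.

import torch

beta2 = 0.8          # made-up; small so bias_correction2 reaches 0.99 quickly
eps = 1.0e-08
zero_step = 0        # hypothetical step at which scalar_exp_avg_sq was last zeroed

scalar_exp_avg_sq = torch.zeros(())
g = torch.tensor(0.5)  # a constant stand-in gradient statistic

for step in range(1, 30):
    scalar_exp_avg_sq.mul_(beta2).add_(g * g, alpha=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step - zero_step)
    corrected = scalar_exp_avg_sq
    if bias_correction2 < 0.99:
        # note: not in-place, as in the hunk above.
        corrected = scalar_exp_avg_sq * (1.0 / bias_correction2)
    denom = corrected.sqrt() + eps
    print(step, round(bias_correction2, 4), float(denom))  # denom stays close to |g| = 0.5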