Repository: https://github.com/k2-fsa/icefall.git

Commit 04d2e10b4f (parent e6d00ee3e4)
Commit message: Version that runs
@@ -66,7 +66,7 @@ class LearnedGradient(Optimizer):
 meta_lr_scale=0.1,
 betas=(0.9, 0.98),
 eps=1.0e-08,
-size_update_period=4,
+size_update_period=1,
 param_min_rms=1.0e-05,
 param_max_rms=2.0,
 lr_mat_min=0.01,
@@ -117,10 +117,10 @@ class LearnedGradient(Optimizer):
 meta_lr_scale = group["meta_lr_scale"]
 size_lr = lr * group["size_lr_scale"]
 beta1, beta2 = group["betas"]
-max_lr_eig = group["max_lr_eig"]
-min_lr_eig = group["max_lr_eig"]
+lr_mat_min = group["lr_mat_min"]
+lr_mat_max = group["lr_mat_max"]
 eps = group["eps"]
-size_update_period = 4
+size_update_period = group["size_update_period"]
 param_min_rms = group["param_min_rms"]
 param_max_rms = group["param_max_rms"]
 lr_est_period = group["lr_est_period"]
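
Background on why the group[...] lines above can crash: torch.optim.Optimizer keeps every hyperparameter in the per-group dict built from the defaults passed to __init__, so reading a key such as "max_lr_eig" that was never registered raises KeyError on the first step(). A minimal toy sketch (not the icefall class; names are illustrative):

import torch
from torch.optim import Optimizer

class ToyOptimizer(Optimizer):
    # Toy example: every key read inside step() must be registered in `defaults`.
    def __init__(self, params, lr=0.01, lr_mat_min=0.01, lr_mat_max=10.0):
        defaults = dict(lr=lr, lr_mat_min=lr_mat_min, lr_mat_max=lr_mat_max)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            lr_mat_min = group["lr_mat_min"]    # fine: present in defaults
            lr_mat_max = group["lr_mat_max"]    # fine: present in defaults
            # group["max_lr_eig"] would raise KeyError here, since that key
            # was never put into `defaults`.
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)

p = torch.nn.Parameter(torch.randn(4))
opt = ToyOptimizer([p])
p.sum().backward()
opt.step()
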
@@ -162,8 +162,13 @@ class LearnedGradient(Optimizer):
 # the parameter tensor.
 param_rms = (p**2).mean().sqrt().add_(eps)
 state["param_rms"] = param_rms
-# scalar exp_avg_sq. not the same as scale_exp_avg_sq
-state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)
+
+state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
+# scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
+# determine the magnitude of the regular step
+state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
+state["scale_grads"] = torch.zeros(size_update_period,
+**kwargs)

 # exp_avg_sq is in the projected space. It is periodically zeroed, and we
 # set zero_step.
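
The replacement lines also correct an API misuse: torch.zeros_like expects a tensor, not a shape, while torch.zeros((), ...) is the idiomatic way to create a zero-dimensional accumulator. A small standalone sketch of the distinction (values and shapes are made up):

import torch

kwargs = dict(device=torch.device("cpu"), dtype=torch.float32)

# Zero-dimensional (scalar) accumulators, like scale_exp_avg_sq / scalar_exp_avg_sq:
scale_exp_avg_sq = torch.zeros((), **kwargs)     # shape torch.Size([])
scalar_exp_avg_sq = torch.zeros((), **kwargs)

# A small per-period buffer, like scale_grads:
size_update_period = 1
scale_grads = torch.zeros(size_update_period, **kwargs)

# torch.zeros_like needs a tensor argument; torch.zeros_like(()) is a TypeError.
template = torch.zeros(())
also_scalar = torch.zeros_like(template, **kwargs)

print(scale_exp_avg_sq.shape, scale_grads.shape, also_scalar.shape)
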
@@ -230,17 +235,21 @@ class LearnedGradient(Optimizer):
 numel = p.numel()
 if numel > 1:
 # Update the size/scale of p, and set param_rms
-state["scale_grads"][step % size_update_period] = (p * grad).sum()
+scale_grads = state["scale_grads"]
+scale_grads[step % size_update_period] = (p * grad).sum()
 if step % size_update_period == size_update_period - 1:
 # this learns the overall scale on the parameter, by shrinking or
 # expanding it.
-state["param_rms"] = (param ** 2).mean().sqrt().clamp_(min=eps)
+param_rms = state["param_rms"]
+param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
 if step > 0:
 self._size_update(p, state,
 scale_grads, param_rms,
 beta1, beta2, step, size_lr,
 param_min_rms, param_max_rms)
-state["delta"].mul_(beta1)
+
+delta = state["delta"]
+delta.mul_(beta1)
 if numel == 1:
 # For parameters with very few elements we just use a form
 # of Adam with a scale factor to reflect the overall
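
The scale_grads buffer above acts as a rolling window: each step stores (p * grad).sum(), and every size_update_period steps the cached param_rms is refreshed and a size update is triggered. A schematic sketch of just that bookkeeping (random stand-in tensors; the real scale update lives in _size_update):

import torch

eps = 1.0e-05
size_update_period = 4

p = torch.randn(10, 10)
state = {
    "scale_grads": torch.zeros(size_update_period),
    "param_rms": (p ** 2).mean().sqrt(),
}

for step in range(8):
    grad = torch.randn_like(p)                   # stand-in gradient
    scale_grads = state["scale_grads"]
    # (p * grad).sum() is the gradient of the loss w.r.t. an overall scale on p.
    scale_grads[step % size_update_period] = (p * grad).sum()
    if step % size_update_period == size_update_period - 1:
        # Refresh the cached RMS of the parameter tensor...
        param_rms = state["param_rms"]
        param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
        # ...and this is where the real code calls self._size_update(...) to
        # shrink or expand p as a whole using the entries of scale_grads.
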
@@ -252,7 +261,8 @@ class LearnedGradient(Optimizer):
 self._update_lrs(group, p, state)
 if step % (lr_est_period * diagonalize_period) == 0:
 self._diagonalize_lrs(group, p, state)
-self._zero_exp_avg_sq()
+self._zero_exp_avg_sq(state)
 if True:
 self._step(group, p, grad, state)
+p.add_(delta)
 state["step"] = step + 1
@@ -262,9 +272,11 @@ class LearnedGradient(Optimizer):
 def _size_update(self,
 p: Tensor,
 state: dict,
-scale_grad: Tensor,
+scale_grads: Tensor,
+param_rms: Tensor,
 beta1: float,
 beta2: float,
 step: int,
 size_lr: float,
 param_min_rms: float,
 param_max_rms: float) -> None:
@@ -298,7 +310,7 @@ class LearnedGradient(Optimizer):

 scale_exp_avg_sq = state["scale_exp_avg_sq"]
 scale_exp_avg_sq.mul_(beta2_corr).add_(
-(scale_grads ** 2).mean(), value=1-beta2_corr)
+(scale_grads ** 2).mean(), alpha=1-beta2_corr)


 # The 1st time we reach here is when size_step == 1.
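
The value→alpha change is needed because Tensor.add_ takes its scale factor through the keyword alpha, whereas value belongs to addcmul_/addcdiv_; passing value= to add_ raises a TypeError. A quick standalone sketch of the two signatures:

import torch

beta2_corr = 0.98
scale_exp_avg_sq = torch.zeros(())
scale_grads = torch.randn(4)

# Tensor.add_(other, alpha=...): `alpha` scales `other` before the addition.
scale_exp_avg_sq.mul_(beta2_corr).add_((scale_grads ** 2).mean(), alpha=1 - beta2_corr)

# Tensor.addcmul_(t1, t2, value=...): `value` is the keyword here, not `alpha`.
exp_avg_sq = torch.zeros(4)
grad = torch.randn(4)
exp_avg_sq.mul_(beta2_corr).addcmul_(grad, grad, value=1 - beta2_corr)
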
@@ -354,7 +366,7 @@ class LearnedGradient(Optimizer):
 # M is the parameter matrix, of shape (-1, size) where the -1 covers
 # all other dimensions of the tensor.
 M_full = p.transpose(dim, -1)
-M = M.reshape(-1, size)
+M = M_full.reshape(-1, size)

 # proj_grad is a summed gradient, of shape (size, size),
 # indexed [old_param_idx][new_param_idx] (the point is,
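
The M_full/M fix simply names the intermediate before flattening; the underlying pattern is to move dimension dim to the end and fold all other dimensions into rows. A sketch with made-up shapes:

import torch

p = torch.randn(3, 5, 7)     # example parameter tensor
dim = 1                      # the dimension whose projection is being learned
size = p.shape[dim]          # 5

# Move `dim` to the end, then flatten every other dimension into the row index.
M_full = p.transpose(dim, -1)       # shape (3, 7, 5)
M = M_full.reshape(-1, size)        # shape (21, 5); reshape copies if needed

assert M.shape == (3 * 7, size)
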
@@ -435,9 +447,12 @@ class LearnedGradient(Optimizer):
 Q = state[f"Q_{dim}"]

 # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
-# torch.linalg.solve(A, B) returns A^{-1} B.
+# torch.solve(A, B) returns A^{-1} B.
+try:
 Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
-assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad)) # TODO: remove
+except:
+# torch.solve has args the other way round and also returns LU.
+Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())

 # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
 proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
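
Spelled out, (eq.1) computes proj2_grad = Q^{-T} proj_grad Q^T without forming an explicit inverse: torch.linalg.solve(Q.t(), proj_grad) yields Q^{-T} proj_grad, and the except branch covers older PyTorch where only torch.solve exists (arguments reversed, and it also returns the LU factors). A standalone sketch of the same computation on random, well-conditioned matrices:

import torch

size = 6
Q = torch.randn(size, size) + size * torch.eye(size)   # well-conditioned example
proj_grad = torch.randn(size, size)

try:
    # torch.linalg.solve(A, B) returns A^{-1} B, so this is Q^{-T} proj_grad.
    Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
except Exception:
    # Older PyTorch: torch.solve(B, A) solves A X = B and returns (X, LU).
    Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())

# (eq.1): proj2_grad = Q^{-T} proj_grad Q^T
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())

# Sanity check: Q^T (Q^{-T} proj_grad) should reproduce proj_grad.
assert torch.allclose(torch.matmul(Q.t(), Q_invt_proj_grad), proj_grad, atol=1e-4)
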
@@ -456,8 +471,14 @@ class LearnedGradient(Optimizer):
 # See (eq.4), Q_delta = proj2_delta q
 Q_delta = torch.matmul(proj2_delta, Q)

+assert not torch.all(proj2_delta == 0)
+
 # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
+try:
 proj_delta = torch.linalg.solve(Q, Q_delta)
+except:
+# torch.solve has args the other way round and also returns LU.
+proj_delta, _ = torch.solve(Q_delta, Q)

 M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta

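
Likewise (eq.3) recovers proj_delta from Q_delta with a linear solve rather than an inverse; the result can be checked by multiplying back. A small sketch (random matrices, torch.linalg API only):

import torch

size = 6
Q = torch.randn(size, size) + size * torch.eye(size)
proj2_delta = torch.randn(size, size)

Q_delta = torch.matmul(proj2_delta, Q)        # (eq.4): Q_delta = proj2_delta Q
proj_delta = torch.linalg.solve(Q, Q_delta)   # (eq.3): proj_delta = Q^{-1} Q_delta

# Multiplying back should reproduce Q_delta up to numerical error.
assert torch.allclose(torch.matmul(Q, proj_delta), Q_delta, atol=1e-3)
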
@@ -474,12 +495,15 @@ class LearnedGradient(Optimizer):
 # there is no momentum on Q.
 Q.add_(Q_delta, alpha=-meta_lr)

+proj_grad.zero_()
+
 def _zero_exp_avg_sq(self,
 state: dict) -> None:
 """
 Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
 """
 state["exp_avg_sq"].zero_()
+state["scalar_exp_avg_sq"].zero_()
 state["zero_step"] = state["step"]

 def _diagonalize_lrs(self,
@@ -704,12 +728,12 @@ class LearnedGradient(Optimizer):
 # This block accumulates the statistics proj_grad_{dim} and
 # grad_cov_{dim}, which are for periodically updating the
 # learning-rate matrices.
-size = grad.shape[size]
+size = grad.shape[dim]
 # accumulate some stats for learning the projections
 proj_grad = state[f"proj_grad_{dim}"]
 grad_cov = state[f"grad_cov_{dim}"]
 this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
-this_g = g.transpose(-1, dim).reshape(-1, size) # M_grad
+this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
 proj_grad.add_(torch.matmul(this_m.t(), this_g))
 # could perhaps accumulate grad_cov less frequently; it's only
 # needed when we rediagonalize which is not that common.
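
This hunk is the accumulation half of the learning-rate-matrix estimation: flatten the tensor around dim (as in the earlier sketch) and add M^T M_grad into proj_grad. A shape-checked sketch of that accumulation alone (made-up shapes; the grad_cov update is outside this hunk):

import torch

p = torch.randn(3, 5, 7)
grad = torch.randn_like(p)
dim = 1
size = grad.shape[dim]                                # 5

proj_grad = torch.zeros(size, size)                   # like state[f"proj_grad_{dim}"]

this_m = p.transpose(-1, dim).reshape(-1, size)       # parameter matrix M, shape (21, 5)
this_g = grad.transpose(-1, dim).reshape(-1, size)    # M_grad, shape (21, 5)

# Summed gradient stats, indexed [old_param_idx][new_param_idx]:
proj_grad.add_(torch.matmul(this_m.t(), this_g))      # shape (5, 5)
assert proj_grad.shape == (size, size)
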
@@ -721,8 +745,8 @@ class LearnedGradient(Optimizer):
 exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
 value=(1-beta2))

-step = state["step"]
-bias_correction2 = 1 - beta2 ** (step + 1)
+this_step = (state["step"] - state["zero_step"])
+bias_correction2 = 1 - beta2 ** (this_step + 1)
 if bias_correction2 < 0.99:
 # note: not in-place.
 exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
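
Counting the bias correction from zero_step rather than from the start of training matters because exp_avg_sq is periodically zeroed: after a reset it behaves like a freshly started second-moment estimate, so the debiasing factor must restart as well. A toy sketch with a constant gradient, where the debiased value can be checked exactly:

import torch

beta2 = 0.98
grad = torch.randn(10)          # held fixed so the debiased value is easy to verify

exp_avg_sq = torch.zeros(10)    # just zeroed by _zero_exp_avg_sq
zero_step = 40                  # state["zero_step"]: when the zeroing happened

for step in range(zero_step, zero_step + 5):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Debias relative to the last reset, not relative to the start of training.
    this_step = step - zero_step
    bias_correction2 = 1 - beta2 ** (this_step + 1)
    if bias_correction2 < 0.99:
        exp_avg_sq_corrected = exp_avg_sq * (1.0 / bias_correction2)  # not in-place
    else:
        exp_avg_sq_corrected = exp_avg_sq

    # With a constant gradient the debiased estimate equals grad**2 exactly.
    assert torch.allclose(exp_avg_sq_corrected, grad * grad, rtol=1e-4)
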
@@ -745,13 +769,11 @@ class LearnedGradient(Optimizer):

 denom = scalar_exp_avg_sq.sqrt() + eps # scalar.

-alpha = state["param_rms"] * (1-beta1) * lr / denom
+alpha = state["param_rms"] * (1-beta1) * -lr / denom

 delta = state["delta"]
 delta.add_(grad, alpha=alpha)
-
-param.add_(delta)


 def _step_scalar(self,
 beta1: float,
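
The sign flip pairs with the p.add_(delta) introduced in step() above: delta now accumulates an already-negated, momentum-smoothed update, so applying it is a plain addition. A reduced sketch of that convention (scalar-style handling only, not the projected update; alpha is converted to a Python float here for simplicity):

import torch

lr = 0.03
beta1, beta2 = 0.9, 0.98
eps = 1.0e-08

p = torch.randn(8)
delta = torch.zeros_like(p)             # state["delta"]
scalar_exp_avg_sq = torch.zeros(())     # state["scalar_exp_avg_sq"]
param_rms = (p ** 2).mean().sqrt()      # state["param_rms"], held fixed here for brevity

for _ in range(10):
    grad = 2 * p                        # gradient of sum(p**2), as a stand-in loss
    scalar_exp_avg_sq.mul_(beta2).add_((grad ** 2).mean(), alpha=1 - beta2)
    denom = scalar_exp_avg_sq.sqrt() + eps

    # Negative alpha: delta already carries the minus sign of gradient descent...
    alpha = float(param_rms * (1 - beta1) * -lr / denom)
    delta.mul_(beta1).add_(grad, alpha=alpha)

    # ...so applying the update is a plain addition, as in step() above.
    p.add_(delta)
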
@@ -1889,8 +1911,8 @@ def _test_eve_cain():
 avg_loss = 0.0
 for epoch in range(150):
 scheduler.step_epoch()
-if epoch == 100 and iter in [2,3]:
-optim.reset_speedup() # check it doesn't crash.
+#if epoch == 100 and iter in [2,3]:
+# optim.reset_speedup() # check it doesn't crash.

 if epoch == 130:
 opts = diagnostics.TensorDiagnosticOptions(
@@ -1906,7 +1928,7 @@ def _test_eve_cain():
 avg_loss = loss.item()
 else:
 avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
-if n == 0 and epoch % 10 == 0:
+if n == 0 and epoch % 5 == 0:
 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
 norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()