mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Version that runs
This commit is contained in:
parent
e6d00ee3e4
commit
04d2e10b4f
@ -66,7 +66,7 @@ class LearnedGradient(Optimizer):
|
|||||||
meta_lr_scale=0.1,
|
meta_lr_scale=0.1,
|
||||||
betas=(0.9, 0.98),
|
betas=(0.9, 0.98),
|
||||||
eps=1.0e-08,
|
eps=1.0e-08,
|
||||||
size_update_period=4,
|
size_update_period=1,
|
||||||
param_min_rms=1.0e-05,
|
param_min_rms=1.0e-05,
|
||||||
param_max_rms=2.0,
|
param_max_rms=2.0,
|
||||||
lr_mat_min=0.01,
|
lr_mat_min=0.01,
|
||||||
@ -117,10 +117,10 @@ class LearnedGradient(Optimizer):
|
|||||||
meta_lr_scale = group["meta_lr_scale"]
|
meta_lr_scale = group["meta_lr_scale"]
|
||||||
size_lr = lr * group["size_lr_scale"]
|
size_lr = lr * group["size_lr_scale"]
|
||||||
beta1, beta2 = group["betas"]
|
beta1, beta2 = group["betas"]
|
||||||
max_lr_eig = group["max_lr_eig"]
|
lr_mat_min = group["lr_mat_min"]
|
||||||
min_lr_eig = group["max_lr_eig"]
|
lr_mat_max = group["lr_mat_max"]
|
||||||
eps = group["eps"]
|
eps = group["eps"]
|
||||||
size_update_period = 4
|
size_update_period = group["size_update_period"]
|
||||||
param_min_rms = group["param_min_rms"]
|
param_min_rms = group["param_min_rms"]
|
||||||
param_max_rms = group["param_max_rms"]
|
param_max_rms = group["param_max_rms"]
|
||||||
lr_est_period = group["lr_est_period"]
|
lr_est_period = group["lr_est_period"]
|
||||||
@ -161,9 +161,14 @@ class LearnedGradient(Optimizer):
|
|||||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||||
# the parameter tensor.
|
# the parameter tensor.
|
||||||
param_rms = (p**2).mean().sqrt().add_(eps)
|
param_rms = (p**2).mean().sqrt().add_(eps)
|
||||||
state["param_rms"] = param_rms
|
state["param_rms"] = param_rms
|
||||||
# scalar exp_avg_sq. not the same as scale_exp_avg_sq
|
|
||||||
state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs)
|
state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
|
||||||
|
# scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
|
||||||
|
# determine the magnitude of the regular step
|
||||||
|
state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
|
||||||
|
state["scale_grads"] = torch.zeros(size_update_period,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
|
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
|
||||||
# set zero_step.
|
# set zero_step.
|
||||||
@ -230,17 +235,21 @@ class LearnedGradient(Optimizer):
|
|||||||
numel = p.numel()
|
numel = p.numel()
|
||||||
if numel > 1:
|
if numel > 1:
|
||||||
# Update the size/scale of p, and set param_rms
|
# Update the size/scale of p, and set param_rms
|
||||||
state["scale_grads"][step % size_update_period] = (p * grad).sum()
|
scale_grads = state["scale_grads"]
|
||||||
|
scale_grads[step % size_update_period] = (p * grad).sum()
|
||||||
if step % size_update_period == size_update_period - 1:
|
if step % size_update_period == size_update_period - 1:
|
||||||
# this learns the overall scale on the parameter, by shrinking or
|
# this learns the overall scale on the parameter, by shrinking or
|
||||||
# expanding it.
|
# expanding it.
|
||||||
state["param_rms"] = (param ** 2).mean().sqrt().clamp_(min=eps)
|
param_rms = state["param_rms"]
|
||||||
|
param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
|
||||||
if step > 0:
|
if step > 0:
|
||||||
self._size_update(p, state,
|
self._size_update(p, state,
|
||||||
scale_grads, param_rms,
|
scale_grads, param_rms,
|
||||||
beta1, beta2, step, size_lr,
|
beta1, beta2, step, size_lr,
|
||||||
param_min_rms, param_max_rms)
|
param_min_rms, param_max_rms)
|
||||||
state["delta"].mul_(beta1)
|
|
||||||
|
delta = state["delta"]
|
||||||
|
delta.mul_(beta1)
|
||||||
if numel == 1:
|
if numel == 1:
|
||||||
# For parameters with very few elements we just use a form
|
# For parameters with very few elements we just use a form
|
||||||
# of Adam with a scale factor to reflect the overall
|
# of Adam with a scale factor to reflect the overall
|
||||||
@ -252,8 +261,9 @@ class LearnedGradient(Optimizer):
|
|||||||
self._update_lrs(group, p, state)
|
self._update_lrs(group, p, state)
|
||||||
if step % (lr_est_period * diagonalize_period) == 0:
|
if step % (lr_est_period * diagonalize_period) == 0:
|
||||||
self._diagonalize_lrs(group, p, state)
|
self._diagonalize_lrs(group, p, state)
|
||||||
self._zero_exp_avg_sq()
|
self._zero_exp_avg_sq(state)
|
||||||
self._step(group, p, grad, state)
|
if True:
|
||||||
|
self._step(group, p, grad, state)
|
||||||
p.add_(delta)
|
p.add_(delta)
|
||||||
state["step"] = step + 1
|
state["step"] = step + 1
|
||||||
|
|
||||||
@ -262,9 +272,11 @@ class LearnedGradient(Optimizer):
|
|||||||
def _size_update(self,
|
def _size_update(self,
|
||||||
p: Tensor,
|
p: Tensor,
|
||||||
state: dict,
|
state: dict,
|
||||||
scale_grad: Tensor,
|
scale_grads: Tensor,
|
||||||
|
param_rms: Tensor,
|
||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
|
step: int,
|
||||||
size_lr: float,
|
size_lr: float,
|
||||||
param_min_rms: float,
|
param_min_rms: float,
|
||||||
param_max_rms: float) -> None:
|
param_max_rms: float) -> None:
|
||||||
@ -298,7 +310,7 @@ class LearnedGradient(Optimizer):
|
|||||||
|
|
||||||
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
||||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||||
(scale_grads ** 2).mean(), value=1-beta2_corr)
|
(scale_grads ** 2).mean(), alpha=1-beta2_corr)
|
||||||
|
|
||||||
|
|
||||||
# The 1st time we reach here is when size_step == 1.
|
# The 1st time we reach here is when size_step == 1.
|
||||||
@ -354,7 +366,7 @@ class LearnedGradient(Optimizer):
|
|||||||
# M is the parameter matrix, of shape (-1, size) where the -1 covers
|
# M is the parameter matrix, of shape (-1, size) where the -1 covers
|
||||||
# all other dimensions of the tensor.
|
# all other dimensions of the tensor.
|
||||||
M_full = p.transpose(dim, -1)
|
M_full = p.transpose(dim, -1)
|
||||||
M = M.reshape(-1, size)
|
M = M_full.reshape(-1, size)
|
||||||
|
|
||||||
# proj_grad is a summed gradient, of shape (size, size),
|
# proj_grad is a summed gradient, of shape (size, size),
|
||||||
# indexed [old_param_idx][new_param_idx] (the point is,
|
# indexed [old_param_idx][new_param_idx] (the point is,
|
||||||
@ -435,9 +447,12 @@ class LearnedGradient(Optimizer):
|
|||||||
Q = state[f"Q_{dim}"]
|
Q = state[f"Q_{dim}"]
|
||||||
|
|
||||||
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
||||||
# torch.linalg.solve(A, B) returns A^{-1} B.
|
# torch.solve(A, B) returns A^{-1} B.
|
||||||
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
|
try:
|
||||||
assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad)) # TODO: remove
|
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
|
||||||
|
except:
|
||||||
|
# torch.solve has args the other way round and also returns LU.
|
||||||
|
Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())
|
||||||
|
|
||||||
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
|
||||||
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
|
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
|
||||||
@ -456,8 +471,14 @@ class LearnedGradient(Optimizer):
|
|||||||
# See (eq.4), Q_delta = proj2_delta q
|
# See (eq.4), Q_delta = proj2_delta q
|
||||||
Q_delta = torch.matmul(proj2_delta, Q)
|
Q_delta = torch.matmul(proj2_delta, Q)
|
||||||
|
|
||||||
|
assert not torch.all(proj2_delta == 0)
|
||||||
|
|
||||||
# See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
|
# See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
|
||||||
proj_delta = torch.linalg.solve(Q, Q_delta)
|
try:
|
||||||
|
proj_delta = torch.linalg.solve(Q, Q_delta)
|
||||||
|
except:
|
||||||
|
# torch.solve has args the other way round and also returns LU.
|
||||||
|
proj_delta, _ = torch.solve(Q_delta, Q)
|
||||||
|
|
||||||
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
|
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
|
||||||
|
|
||||||
@ -474,12 +495,15 @@ class LearnedGradient(Optimizer):
|
|||||||
# there is no momentum on Q.
|
# there is no momentum on Q.
|
||||||
Q.add_(Q_delta, alpha=-meta_lr)
|
Q.add_(Q_delta, alpha=-meta_lr)
|
||||||
|
|
||||||
|
proj_grad.zero_()
|
||||||
|
|
||||||
def _zero_exp_avg_sq(self,
|
def _zero_exp_avg_sq(self,
|
||||||
state: dict) -> None:
|
state: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
|
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
|
||||||
"""
|
"""
|
||||||
state["exp_avg_sq"].zero_()
|
state["exp_avg_sq"].zero_()
|
||||||
|
state["scalar_exp_avg_sq"].zero_()
|
||||||
state["zero_step"] = state["step"]
|
state["zero_step"] = state["step"]
|
||||||
|
|
||||||
def _diagonalize_lrs(self,
|
def _diagonalize_lrs(self,
|
||||||
@ -704,12 +728,12 @@ class LearnedGradient(Optimizer):
|
|||||||
# This block accumulates the statistics proj_grad_{dim} and
|
# This block accumulates the statistics proj_grad_{dim} and
|
||||||
# grad_cov_{dim}, which are for periodically updating the
|
# grad_cov_{dim}, which are for periodically updating the
|
||||||
# learning-rate matrices.
|
# learning-rate matrices.
|
||||||
size = grad.shape[size]
|
size = grad.shape[dim]
|
||||||
# accumulate some stats for learning the projections
|
# accumulate some stats for learning the projections
|
||||||
proj_grad = state[f"proj_grad_{dim}"]
|
proj_grad = state[f"proj_grad_{dim}"]
|
||||||
grad_cov = state[f"grad_cov_{dim}"]
|
grad_cov = state[f"grad_cov_{dim}"]
|
||||||
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
|
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
|
||||||
this_g = g.transpose(-1, dim).reshape(-1, size) # M_grad
|
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
|
||||||
proj_grad.add_(torch.matmul(this_m.t(), this_g))
|
proj_grad.add_(torch.matmul(this_m.t(), this_g))
|
||||||
# could perhaps accumulate grad_cov less frequently; it's only
|
# could perhaps accumulate grad_cov less frequently; it's only
|
||||||
# needed when we rediagonalize which is not that common.
|
# needed when we rediagonalize which is not that common.
|
||||||
@ -721,8 +745,8 @@ class LearnedGradient(Optimizer):
|
|||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||||
value=(1-beta2))
|
value=(1-beta2))
|
||||||
|
|
||||||
step = state["step"]
|
this_step = (state["step"] - state["zero_step"])
|
||||||
bias_correction2 = 1 - beta2 ** (step + 1)
|
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||||
if bias_correction2 < 0.99:
|
if bias_correction2 < 0.99:
|
||||||
# note: not in-place.
|
# note: not in-place.
|
||||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||||
@ -745,13 +769,11 @@ class LearnedGradient(Optimizer):
|
|||||||
|
|
||||||
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
|
denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
|
||||||
|
|
||||||
alpha = state["param_rms"] * (1-beta1) * lr / denom
|
alpha = state["param_rms"] * (1-beta1) * -lr / denom
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
delta.add_(grad, alpha=alpha)
|
delta.add_(grad, alpha=alpha)
|
||||||
|
|
||||||
param.add_(delta)
|
|
||||||
|
|
||||||
|
|
||||||
def _step_scalar(self,
|
def _step_scalar(self,
|
||||||
beta1: float,
|
beta1: float,
|
||||||
@ -1889,8 +1911,8 @@ def _test_eve_cain():
|
|||||||
avg_loss = 0.0
|
avg_loss = 0.0
|
||||||
for epoch in range(150):
|
for epoch in range(150):
|
||||||
scheduler.step_epoch()
|
scheduler.step_epoch()
|
||||||
if epoch == 100 and iter in [2,3]:
|
#if epoch == 100 and iter in [2,3]:
|
||||||
optim.reset_speedup() # check it doesn't crash.
|
# optim.reset_speedup() # check it doesn't crash.
|
||||||
|
|
||||||
if epoch == 130:
|
if epoch == 130:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
@ -1906,7 +1928,7 @@ def _test_eve_cain():
|
|||||||
avg_loss = loss.item()
|
avg_loss = loss.item()
|
||||||
else:
|
else:
|
||||||
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
|
||||||
if n == 0 and epoch % 10 == 0:
|
if n == 0 and epoch % 5 == 0:
|
||||||
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
||||||
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
||||||
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user