Version that runs

This commit is contained in:
Daniel Povey 2022-07-08 04:37:46 +08:00
parent e6d00ee3e4
commit 04d2e10b4f

View File

@ -66,7 +66,7 @@ class LearnedGradient(Optimizer):
meta_lr_scale=0.1, meta_lr_scale=0.1,
betas=(0.9, 0.98), betas=(0.9, 0.98),
eps=1.0e-08, eps=1.0e-08,
size_update_period=4, size_update_period=1,
param_min_rms=1.0e-05, param_min_rms=1.0e-05,
param_max_rms=2.0, param_max_rms=2.0,
lr_mat_min=0.01, lr_mat_min=0.01,
@ -117,10 +117,10 @@ class LearnedGradient(Optimizer):
meta_lr_scale = group["meta_lr_scale"] meta_lr_scale = group["meta_lr_scale"]
size_lr = lr * group["size_lr_scale"] size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
max_lr_eig = group["max_lr_eig"] lr_mat_min = group["lr_mat_min"]
min_lr_eig = group["max_lr_eig"] lr_mat_max = group["lr_mat_max"]
eps = group["eps"] eps = group["eps"]
size_update_period = 4 size_update_period = group["size_update_period"]
param_min_rms = group["param_min_rms"] param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"] param_max_rms = group["param_max_rms"]
lr_est_period = group["lr_est_period"] lr_est_period = group["lr_est_period"]
@ -161,9 +161,14 @@ class LearnedGradient(Optimizer):
# "param_rms" just periodically records the scalar root-mean-square value of # "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor. # the parameter tensor.
param_rms = (p**2).mean().sqrt().add_(eps) param_rms = (p**2).mean().sqrt().add_(eps)
state["param_rms"] = param_rms state["param_rms"] = param_rms
# scalar exp_avg_sq. not the same as scale_exp_avg_sq
state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs) state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
# scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
# determine the magnitude of the regular step
state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
state["scale_grads"] = torch.zeros(size_update_period,
**kwargs)
# exp_avg_sq is in the projected space. It is periodically zeroed, and we # exp_avg_sq is in the projected space. It is periodically zeroed, and we
# set zero_step. # set zero_step.
@ -230,17 +235,21 @@ class LearnedGradient(Optimizer):
numel = p.numel() numel = p.numel()
if numel > 1: if numel > 1:
# Update the size/scale of p, and set param_rms # Update the size/scale of p, and set param_rms
state["scale_grads"][step % size_update_period] = (p * grad).sum() scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum()
if step % size_update_period == size_update_period - 1: if step % size_update_period == size_update_period - 1:
# this learns the overall scale on the parameter, by shrinking or # this learns the overall scale on the parameter, by shrinking or
# expanding it. # expanding it.
state["param_rms"] = (param ** 2).mean().sqrt().clamp_(min=eps) param_rms = state["param_rms"]
param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
if step > 0: if step > 0:
self._size_update(p, state, self._size_update(p, state,
scale_grads, param_rms, scale_grads, param_rms,
beta1, beta2, step, size_lr, beta1, beta2, step, size_lr,
param_min_rms, param_max_rms) param_min_rms, param_max_rms)
state["delta"].mul_(beta1)
delta = state["delta"]
delta.mul_(beta1)
if numel == 1: if numel == 1:
# For parameters with very few elements we just use a form # For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall # of Adam with a scale factor to reflect the overall
@ -252,8 +261,9 @@ class LearnedGradient(Optimizer):
self._update_lrs(group, p, state) self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0: if step % (lr_est_period * diagonalize_period) == 0:
self._diagonalize_lrs(group, p, state) self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq() self._zero_exp_avg_sq(state)
self._step(group, p, grad, state) if True:
self._step(group, p, grad, state)
p.add_(delta) p.add_(delta)
state["step"] = step + 1 state["step"] = step + 1
@ -262,9 +272,11 @@ class LearnedGradient(Optimizer):
def _size_update(self, def _size_update(self,
p: Tensor, p: Tensor,
state: dict, state: dict,
scale_grad: Tensor, scale_grads: Tensor,
param_rms: Tensor,
beta1: float, beta1: float,
beta2: float, beta2: float,
step: int,
size_lr: float, size_lr: float,
param_min_rms: float, param_min_rms: float,
param_max_rms: float) -> None: param_max_rms: float) -> None:
@ -298,7 +310,7 @@ class LearnedGradient(Optimizer):
scale_exp_avg_sq = state["scale_exp_avg_sq"] scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg_sq.mul_(beta2_corr).add_( scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(), value=1-beta2_corr) (scale_grads ** 2).mean(), alpha=1-beta2_corr)
# The 1st time we reach here is when size_step == 1. # The 1st time we reach here is when size_step == 1.
@ -354,7 +366,7 @@ class LearnedGradient(Optimizer):
# M is the parameter matrix, of shape (-1, size) where the -1 covers # M is the parameter matrix, of shape (-1, size) where the -1 covers
# all other dimensions of the tensor. # all other dimensions of the tensor.
M_full = p.transpose(dim, -1) M_full = p.transpose(dim, -1)
M = M.reshape(-1, size) M = M_full.reshape(-1, size)
# proj_grad is a summed gradient, of shape (size, size), # proj_grad is a summed gradient, of shape (size, size),
# indexed [old_param_idx][new_param_idx] (the point is, # indexed [old_param_idx][new_param_idx] (the point is,
@ -435,9 +447,12 @@ class LearnedGradient(Optimizer):
Q = state[f"Q_{dim}"] Q = state[f"Q_{dim}"]
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
# torch.linalg.solve(A, B) returns A^{-1} B. # torch.solve(A, B) returns A^{-1} B.
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad) try:
assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad)) # TODO: remove Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
except:
# torch.solve has args the other way round and also returns LU.
Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t()) proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
@ -456,8 +471,14 @@ class LearnedGradient(Optimizer):
# See (eq.4), Q_delta = proj2_delta q # See (eq.4), Q_delta = proj2_delta q
Q_delta = torch.matmul(proj2_delta, Q) Q_delta = torch.matmul(proj2_delta, Q)
assert not torch.all(proj2_delta == 0)
# See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
proj_delta = torch.linalg.solve(Q, Q_delta) try:
proj_delta = torch.linalg.solve(Q, Q_delta)
except:
# torch.solve has args the other way round and also returns LU.
proj_delta, _ = torch.solve(Q_delta, Q)
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
@ -474,12 +495,15 @@ class LearnedGradient(Optimizer):
# there is no momentum on Q. # there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr) Q.add_(Q_delta, alpha=-meta_lr)
proj_grad.zero_()
def _zero_exp_avg_sq(self, def _zero_exp_avg_sq(self,
state: dict) -> None: state: dict) -> None:
""" """
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"] Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
""" """
state["exp_avg_sq"].zero_() state["exp_avg_sq"].zero_()
state["scalar_exp_avg_sq"].zero_()
state["zero_step"] = state["step"] state["zero_step"] = state["step"]
def _diagonalize_lrs(self, def _diagonalize_lrs(self,
@ -704,12 +728,12 @@ class LearnedGradient(Optimizer):
# This block accumulates the statistics proj_grad_{dim} and # This block accumulates the statistics proj_grad_{dim} and
# grad_cov_{dim}, which are for periodically updating the # grad_cov_{dim}, which are for periodically updating the
# learning-rate matrices. # learning-rate matrices.
size = grad.shape[size] size = grad.shape[dim]
# accumulate some stats for learning the projections # accumulate some stats for learning the projections
proj_grad = state[f"proj_grad_{dim}"] proj_grad = state[f"proj_grad_{dim}"]
grad_cov = state[f"grad_cov_{dim}"] grad_cov = state[f"grad_cov_{dim}"]
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
this_g = g.transpose(-1, dim).reshape(-1, size) # M_grad this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
proj_grad.add_(torch.matmul(this_m.t(), this_g)) proj_grad.add_(torch.matmul(this_m.t(), this_g))
# could perhaps accumulate grad_cov less frequently; it's only # could perhaps accumulate grad_cov less frequently; it's only
# needed when we rediagonalize which is not that common. # needed when we rediagonalize which is not that common.
@ -721,8 +745,8 @@ class LearnedGradient(Optimizer):
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2)) value=(1-beta2))
step = state["step"] this_step = (state["step"] - state["zero_step"])
bias_correction2 = 1 - beta2 ** (step + 1) bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99: if bias_correction2 < 0.99:
# note: not in-place. # note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@ -745,13 +769,11 @@ class LearnedGradient(Optimizer):
denom = scalar_exp_avg_sq.sqrt() + eps # scalar. denom = scalar_exp_avg_sq.sqrt() + eps # scalar.
alpha = state["param_rms"] * (1-beta1) * lr / denom alpha = state["param_rms"] * (1-beta1) * -lr / denom
delta = state["delta"] delta = state["delta"]
delta.add_(grad, alpha=alpha) delta.add_(grad, alpha=alpha)
param.add_(delta)
def _step_scalar(self, def _step_scalar(self,
beta1: float, beta1: float,
@ -1889,8 +1911,8 @@ def _test_eve_cain():
avg_loss = 0.0 avg_loss = 0.0
for epoch in range(150): for epoch in range(150):
scheduler.step_epoch() scheduler.step_epoch()
if epoch == 100 and iter in [2,3]: #if epoch == 100 and iter in [2,3]:
optim.reset_speedup() # check it doesn't crash. # optim.reset_speedup() # check it doesn't crash.
if epoch == 130: if epoch == 130:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
@ -1906,7 +1928,7 @@ def _test_eve_cain():
avg_loss = loss.item() avg_loss = loss.item()
else: else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item() avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
if n == 0 and epoch % 10 == 0: if n == 0 and epoch % 5 == 0:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()