From 04d2e10b4f4acae1b1c8c4d44c6b22da0cbe2821 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Jul 2022 04:37:46 +0800 Subject: [PATCH] Version that runs --- .../ASR/pruned_transducer_stateless7/optim.py | 80 ++++++++++++------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index f27ff153e..8fe3fffaa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -66,7 +66,7 @@ class LearnedGradient(Optimizer): meta_lr_scale=0.1, betas=(0.9, 0.98), eps=1.0e-08, - size_update_period=4, + size_update_period=1, param_min_rms=1.0e-05, param_max_rms=2.0, lr_mat_min=0.01, @@ -117,10 +117,10 @@ class LearnedGradient(Optimizer): meta_lr_scale = group["meta_lr_scale"] size_lr = lr * group["size_lr_scale"] beta1, beta2 = group["betas"] - max_lr_eig = group["max_lr_eig"] - min_lr_eig = group["max_lr_eig"] + lr_mat_min = group["lr_mat_min"] + lr_mat_max = group["lr_mat_max"] eps = group["eps"] - size_update_period = 4 + size_update_period = group["size_update_period"] param_min_rms = group["param_min_rms"] param_max_rms = group["param_max_rms"] lr_est_period = group["lr_est_period"] @@ -161,9 +161,14 @@ class LearnedGradient(Optimizer): # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. param_rms = (p**2).mean().sqrt().add_(eps) - state["param_rms"] = param_rms - # scalar exp_avg_sq. not the same as scale_exp_avg_sq - state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs) + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros((), **kwargs) + # scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to + # determine the magnitude of the regular step + state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs) + state["scale_grads"] = torch.zeros(size_update_period, + **kwargs) # exp_avg_sq is in the projected space. It is periodically zeroed, and we # set zero_step. @@ -230,17 +235,21 @@ class LearnedGradient(Optimizer): numel = p.numel() if numel > 1: # Update the size/scale of p, and set param_rms - state["scale_grads"][step % size_update_period] = (p * grad).sum() + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum() if step % size_update_period == size_update_period - 1: # this learns the overall scale on the parameter, by shrinking or # expanding it. - state["param_rms"] = (param ** 2).mean().sqrt().clamp_(min=eps) + param_rms = state["param_rms"] + param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps)) if step > 0: self._size_update(p, state, scale_grads, param_rms, beta1, beta2, step, size_lr, param_min_rms, param_max_rms) - state["delta"].mul_(beta1) + + delta = state["delta"] + delta.mul_(beta1) if numel == 1: # For parameters with very few elements we just use a form # of Adam with a scale factor to reflect the overall @@ -252,8 +261,9 @@ class LearnedGradient(Optimizer): self._update_lrs(group, p, state) if step % (lr_est_period * diagonalize_period) == 0: self._diagonalize_lrs(group, p, state) - self._zero_exp_avg_sq() - self._step(group, p, grad, state) + self._zero_exp_avg_sq(state) + if True: + self._step(group, p, grad, state) p.add_(delta) state["step"] = step + 1 @@ -262,9 +272,11 @@ class LearnedGradient(Optimizer): def _size_update(self, p: Tensor, state: dict, - scale_grad: Tensor, + scale_grads: Tensor, + param_rms: Tensor, beta1: float, beta2: float, + step: int, size_lr: float, param_min_rms: float, param_max_rms: float) -> None: @@ -298,7 +310,7 @@ class LearnedGradient(Optimizer): scale_exp_avg_sq = state["scale_exp_avg_sq"] scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(), value=1-beta2_corr) + (scale_grads ** 2).mean(), alpha=1-beta2_corr) # The 1st time we reach here is when size_step == 1. @@ -354,7 +366,7 @@ class LearnedGradient(Optimizer): # M is the parameter matrix, of shape (-1, size) where the -1 covers # all other dimensions of the tensor. M_full = p.transpose(dim, -1) - M = M.reshape(-1, size) + M = M_full.reshape(-1, size) # proj_grad is a summed gradient, of shape (size, size), # indexed [old_param_idx][new_param_idx] (the point is, @@ -435,9 +447,12 @@ class LearnedGradient(Optimizer): Q = state[f"Q_{dim}"] # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T - # torch.linalg.solve(A, B) returns A^{-1} B. - Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad) - assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad)) # TODO: remove + # torch.solve(A, B) returns A^{-1} B. + try: + Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad) + except: + # torch.solve has args the other way round and also returns LU. + Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t()) # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t()) @@ -456,8 +471,14 @@ class LearnedGradient(Optimizer): # See (eq.4), Q_delta = proj2_delta q Q_delta = torch.matmul(proj2_delta, Q) + assert not torch.all(proj2_delta == 0) + # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta - proj_delta = torch.linalg.solve(Q, Q_delta) + try: + proj_delta = torch.linalg.solve(Q, Q_delta) + except: + # torch.solve has args the other way round and also returns LU. + proj_delta, _ = torch.solve(Q_delta, Q) M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta @@ -474,12 +495,15 @@ class LearnedGradient(Optimizer): # there is no momentum on Q. Q.add_(Q_delta, alpha=-meta_lr) + proj_grad.zero_() + def _zero_exp_avg_sq(self, state: dict) -> None: """ Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"] """ state["exp_avg_sq"].zero_() + state["scalar_exp_avg_sq"].zero_() state["zero_step"] = state["step"] def _diagonalize_lrs(self, @@ -704,12 +728,12 @@ class LearnedGradient(Optimizer): # This block accumulates the statistics proj_grad_{dim} and # grad_cov_{dim}, which are for periodically updating the # learning-rate matrices. - size = grad.shape[size] + size = grad.shape[dim] # accumulate some stats for learning the projections proj_grad = state[f"proj_grad_{dim}"] grad_cov = state[f"grad_cov_{dim}"] this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M - this_g = g.transpose(-1, dim).reshape(-1, size) # M_grad + this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad proj_grad.add_(torch.matmul(this_m.t(), this_g)) # could perhaps accumulate grad_cov less frequently; it's only # needed when we rediagonalize which is not that common. @@ -721,8 +745,8 @@ class LearnedGradient(Optimizer): exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1-beta2)) - step = state["step"] - bias_correction2 = 1 - beta2 ** (step + 1) + this_step = (state["step"] - state["zero_step"]) + bias_correction2 = 1 - beta2 ** (this_step + 1) if bias_correction2 < 0.99: # note: not in-place. exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) @@ -745,13 +769,11 @@ class LearnedGradient(Optimizer): denom = scalar_exp_avg_sq.sqrt() + eps # scalar. - alpha = state["param_rms"] * (1-beta1) * lr / denom + alpha = state["param_rms"] * (1-beta1) * -lr / denom delta = state["delta"] delta.add_(grad, alpha=alpha) - param.add_(delta) - def _step_scalar(self, beta1: float, @@ -1889,8 +1911,8 @@ def _test_eve_cain(): avg_loss = 0.0 for epoch in range(150): scheduler.step_epoch() - if epoch == 100 and iter in [2,3]: - optim.reset_speedup() # check it doesn't crash. + #if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. if epoch == 130: opts = diagnostics.TensorDiagnosticOptions( @@ -1906,7 +1928,7 @@ def _test_eve_cain(): avg_loss = loss.item() else: avg_loss = 0.95 * avg_loss + 0.05 * loss.item() - if n == 0 and epoch % 10 == 0: + if n == 0 and epoch % 5 == 0: norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()