From e9e2a85c95f03d508f09446ee19aa0fd89643854 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 6 Jul 2022 13:35:19 -0700
Subject: [PATCH] In the middle of reworking for new idea

---
 .../ASR/pruned_transducer_stateless7/optim.py | 394 +++++++++++++-----
 1 file changed, 285 insertions(+), 109 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 7c64b67a0..458fbec44 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -48,7 +48,10 @@ class LearnedGradient(Optimizer):
              learning the scale on the parameters (we'll keep it >= this size)
       param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
              learning the scale on the parameters (we'll keep it <= this size)
-      lr_est_period: The periodicity, in steps, with which we update the learning-rate matrix.
+      lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
+      orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
+             gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
+             than we update the learning-rate matrices.
       max_step_scale: Value that determines limit on how large steps can be per element,
              intended to prevent instability during early steps before the
              learning-rate matrices are well learned.
@@ -58,12 +61,14 @@
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=1.0,
+        meta_lr_scale=0.1,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
         lr_est_period=10,
+        max_lr_eig=4.0,
+        orthogonalize_period=4,
         max_step_scale=2.0
     ):

@@ -77,6 +82,8 @@
             param_min_rms=param_min_rms,
             param_max_rms=param_max_rms,
             lr_est_period=lr_est_period,
+            max_lr_eig=max_lr_eig,
+            orthogonalize_period=orthogonalize_period,
             max_step_scale=max_step_scale,
         )

@@ -104,16 +111,15 @@
             meta_lr_scale = group["meta_lr_scale"]
             size_lr = lr * group["size_lr_scale"]
             beta1, beta2 = group["betas"]
+            max_lr_eig = group["max_lr_eig"]
             eps = group["eps"]
             size_update_period = 4
             param_min_rms = group["param_min_rms"]
             param_max_rms = group["param_max_rms"]
             lr_est_period = group["lr_est_period"]
+            orthogonalize_period = group["orthogonalize_period"]
             max_step_scale = group["max_step_scale"]

-            small_tensor_cutoff = 5  # cutoff below which we use Adam
-
             for p in group["params"]:
                 if p.grad is None:
                     continue
@@ -134,60 +140,69 @@
                     kwargs = {'device': p.device, 'dtype': p.dtype}

-                    if p.numel() > 1:
-                        # we learn the scale on the parameters in a way that
-                        # also emulates Eve.  scale_exp_avg_sq is the
-                        # moving-average square of the gradient w.r.t. this
-                        # scalar scale.
-                        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)
-                        state["scale_exp_avg"] = torch.zeros((), **kwargs)
-                        # for speed, we do the scale update only periodically,
-                        # and meanwhile we store the derivs w.r.t. the grad scale
-                        # as a list
-                        state["scale_grads"] = torch.zeros(size_update_period, **kwargs)

-                    state["exp_avg"] = torch.zeros_like(
+                    # 'delta' implements conventional momentum.  There are
+                    # several different kinds of update going on, so rather than
+                    # compute "exp_avg" like in Adam, we store and decay a
+                    # parameter-change "delta", which combines all forms of
+                    # update.  This is equivalent to how it's done in Adam,
+                    # except for the first few steps.
+                    state["delta"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )
+                    # scalar exp_avg_sq.  not the same as scale_exp_avg_sq; it is
+                    # needed by _step_scalar() as well as _step(), so we create it
+                    # unconditionally.
+                    state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
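+                    # Editor's sketch (illustration only; not from the patch): with
+                    # a fixed step-size factor alpha, decaying a parameter-change
+                    # "delta" is the same as Adam-style momentum on the gradient,
+                    # since delta_t == alpha * exp_avg_t at every step:
+                    #
+                    #   import torch
+                    #   torch.manual_seed(0)
+                    #   beta1, alpha = 0.9, -0.1
+                    #   exp_avg = torch.zeros(3); delta = torch.zeros(3)
+                    #   p1 = torch.zeros(3); p2 = torch.zeros(3)
+                    #   for g in [torch.randn(3) for _ in range(100)]:
+                    #       exp_avg.mul_(beta1).add_(g, alpha=(1 - beta1))
+                    #       p1 += alpha * exp_avg
+                    #       delta.mul_(beta1).add_(g, alpha=alpha * (1 - beta1))
+                    #       p2 += delta
+                    #   assert torch.allclose(p1, p2, atol=1e-5)
+                    #
+                    # The "first few steps" caveat in the comment above concerns
+                    # bias correction and step-size factors that change over time.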
+ state["delta"] = torch.zeros_like( p, memory_format=torch.preserve_format ) - if p.numel() > small_tensor_cutoff: - # moving_grad is a moving average of the raw gradient. - # We do the accumulation on the input side because it - # makes it easier to get grads w.r.t. the learning rate - # matrices. - - # "scale" is just a per-tensor scalar that determines the - # rate of learning (in terms of the rms value of the change - # in parameter) + if p.numel() > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. param_rms = (p**2).mean().sqrt().add_(eps) - state["param_rms"] = param_rms + state["param_rms"] = param_rms + # scalar exp_avg_sq. not the same as scale_exp_avg_sq + state["exp_avg_sq_scalar"] = torch.zeros_like((), **kwargs) - for dim in range(p.ndim): - size = p.shape[dim] - if size == 1 or size == p.numel(): - # if size == p.numel(), is_one_axis == True above - continue + for dim in range(p.ndim): + size = p.shape[dim] + if size == 1: + # if size == p.numel(), is_one_axis == True above + continue - state[f"lr_{dim}"] = torch.eye(size, **kwargs) - # Some parts of the update -- the parts where we are - # doing to update lr_{dim} will require grad. these - # parts will be done inside torch.enable_grad(). - # The rest of this function is inside no_grad(), so - # it will ignore the requires_grad == True. - state[f"lr_{dim}"].requires_grad = True + # exp_avg_sq is in the projected space. + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - # grad_cov_{dim} is the covariance of gradients on this axis, - # treating all other axes as as a batch axis. This is needed - # only for purposes of computing the scalar factor on the - # learning rate; and, as a result of this, also contributes something - # to the gradient w.r.t. f"lr_{dim}", which is one reason - # why we allocate a variable to keep track of its moving average - # instead of just using a temporary and smoothing the scalar factor. - state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs) - else: - # size <= small_tensor_cutoff: do a form of Adam - state["exp_avg_sq"] = torch.zeros_like(p, - mamory_format=preserve_memory_format) + + # proj_{dim} will be the learning-rate matrix for this + # dimension, times an orthogonal matrix that we'll periodically + # re-estimate when we "reorthogonalize" (this orthogonalizes + # the gradient covariance). + state[f"proj_{dim}"] = torch.eye(size, **kwargs) + + # proj_grad_{dim} is the gradient w.r.t. proj_{grad} + # (technically w.r.t a matrix to the left of + # proj_{grad}, we'll project it later); we accumulate + # this for `lr_est_period` steps, and then update + # proj_{dim} and zero this matrix + state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs) + + # grad_cov_{dim} is the covariance of gradients on this axis (without + # any co-ordinate changes), treating all other axes as as a batch axis. + # This is needed when we re-orthogonalize, and to compute the + # grad_rms_{dim}. We store it as a decaying average, decaying with beta2 + + # only for purposes of computing the scalar factor on the + # learning rate; and, as a result of this, also contributes something + # to the gradient w.r.t. f"lr_{dim}", which is one reason + # why we allocate a variable to keep track of its moving average + # instead of just using a temporary and smoothing the scalar factor. 
+ state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs) + + # proj_grad_var_{dim} is the variance of gradients of the thing we're + # going to multiply proj_{dim} by (between it and the "underlying + # parameter matrix"), aggregated over each of its 2 dimensions + # and decayed over time with beta1 each time we do a step on proj_{dim}, + # e.g. every lr_est_period steps. It's a factored representation + # of something like exp_avg_sq, used for estimating proj_{dim}; we avoid + # storing an exp_avg_sq for proj_{dim}, to save GPU memory. + state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs) step = state["step"] @@ -206,14 +221,14 @@ class LearnedGradient(Optimizer): beta1, beta2, step, size_lr, param_min_rms, param_max_rms) - if numel <= small_tensor_cutoff: + if numel == 1: # For parameters with very few elements we just use a form # of Adam with a scale factor to reflect the overall # parameter rms. - self._step_small(beta1, beta2, eps, lr, p, grad, state) + self._step_scalar(beta1, beta2, eps, lr, p, grad, state) else: if step % lr_est_period == 0: - self._update_lrs(group, grad, state) + self._update_lrs(group, p, state) self._step(group, p, grad, state) state["step"] = step + 1 @@ -289,28 +304,129 @@ class LearnedGradient(Optimizer): def _update_lrs(self, group: dict, - grad: Tensor, + p: Tensor, state: dict) -> None: """ Updates the learning-rate matrices state["lr_{dim}"]. + Args: group: dict to look up configuration values - grad: current gradient for the parameter that we are updating + p: parameter matrix that we are updating. The learning rate matrices + are actually factors of p, so p itself will change when we change + them. state: state dict for the current parameter """ - # meta_lr is the learning rate for the learning rate matrix. - meta_lr = group["lr"] * group["meta_lr_scale"] + # meta_lr is the learning rate for the learning rate matrix. Including the factor + # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant + # to the lr_est_period (close to convergence). + meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5) beta1, beta2 = group["betas"] eps = group["eps"] moving_g = state["exp_avg"] - ndim = p.ndim - # Update grad_cov. - self._store_grad_stats(grad, state, beta2) + for dim in range(ndim): + size = p.shape[dim] + if size == 1: + continue + + p_t = p.transpose(dim, -1) + this_p = p_t.reshape(-1, size) + + # proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry + # if this notation about the indexing is unclear, just keep reading). + # + # Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being + # all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M, + # + # Then proj_grad is the accumulated value of M_grad M^T, which will be of shape + # (size, size). If we view the parameter matrix as actually (proj M) with + # proj being numerically equal to I, i.e. as + # matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write + # it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj) + # = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad. + # + # Now, proj_grad is convenient to compute but it's not in a very normalized + # space, we want to normalize it first. We're going to view M as + # M = T N, + # where T is our proj_{dim} matrix (see the code), and N == T^{-1} M. + # It's more convenient to have the projection between T and N, i.e. 

-        with torch.enable_grad():
             # To update the learning rate matrices, we are asking the question,
             # "What if, on the previous step, we had used a slightly different
             # learning-rate matrix?
@@ -345,7 +461,7 @@
             # [something] as it contributes no gradient, becomes -alpha *
             # (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
             # is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
-            # negative of hte iss.
+            # negative of the loss.
             neg_loss = (grad * moving_d_norm).sum()

             # after the following, there will be grads in state[f"lr_{dim}"]
@@ -357,6 +473,32 @@
             if p.shape[dim] != 1:
                 self._update_lr(state[f"lr_{dim}"], meta_lr)

+    def _get_matrix_var(self,
+                        x: Tensor,
+                        eps: float) -> Tensor:
+        """
+        Estimate the variance of the elements of x, which must be a tensor with
+        exactly 2 dimensions, neither of which can have size == 1.  This is done
+        using a structured representation, where the variance is modeled as a
+        product of per-row and per-column factors (otherwise, it wouldn't be
+        possible to estimate a variance from a single sample).  Eventually we may
+        do this via aggregation across time as well; for now this gives a
+        high-variance estimate even for largish dimensions, e.g. if
+        x.shape == (100, 100) the central 95% of the estimate ranges from about
+        0.6 to 1.6.
+
+        Args:
+           x: tensor to estimate the variance of the elements of.  Requires
+              x.ndim == 2; neither dim may have size 1.
+           eps: Epsilon to avoid zero variance; dimensionally, eps is an rms value.
+        """
+        assert x.ndim == 2 and 1 not in list(x.shape)  # TODO: remove?
+
+        var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
+        var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
+        # estimate var0 one more time.
+        var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
+        return var1 * var0
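+    # Editor's sketch (illustration only; not from the patch): empirical behaviour
+    # of the factored estimate in _get_matrix_var() for unit-variance Gaussian
+    # elements.  The product var0 * var1 broadcasts to shape (100, 100) and should
+    # average roughly 1, with the wide spread mentioned in the docstring:
+    #
+    #   import torch
+    #   torch.manual_seed(0)
+    #   x = torch.randn(100, 100)
+    #   eps = 1.0e-08
+    #   var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
+    #   var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
+    #   var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
+    #   est = var0 * var1
+    #   print(est.shape, est.mean())   # torch.Size([100, 100]), ~1.0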

     def _update_lr(self,
@@ -421,7 +563,7 @@
               state: dict):
         """
         This function does the core update of self._step, in the case where the tensor
-        has more than about 5 elements.  It multiplies the moving-average gradient tensor by a
+        has more than 1 element.  It multiplies the moving-average gradient tensor by a
         state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
         and takes a step in that direction.
@@ -434,55 +576,60 @@
        This function modifies p.
        """
        lr = group["lr"]
-        beta1 = group["betas"][0]
+        beta1, beta2 = group["betas"]
        eps = group["eps"]
-        max_step_scale = group["max_step_scale"]

-        moving_g = state["exp_avg"]
-        # include the current grad in moving_g
-        moving_g.mul_(beta1).add(grad, alpha=(1-beta1))
-        # get the parameter delta (up to a scalar factor) by multiplying by a
-        # learning-rate matrix for each dimension.
-        moving_d = self._multiply_by_lr(moving_g)
+        for dim in range(grad.ndim):
+            size = grad.shape[dim]
+            if size == 1:
+                continue
+            # accumulate some stats for learning the projections
+            proj_grad = state[f"proj_grad_{dim}"]
+            grad_cov = state[f"grad_cov_{dim}"]
+            this_p = p.transpose(-1, dim).reshape(-1, size)
+            this_g = grad.transpose(-1, dim).reshape(-1, size)
+            proj_grad.add_(torch.matmul(this_g.t(), this_p))
+            # could perhaps accumulate grad_cov less frequently.
+            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
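+            # Editor's note (illustration only): for a 2-D weight p of shape
+            # (out, in), at dim == 0 the two reshapes above give this_p and this_g
+            # of shape (in, out), so proj_grad and grad_cov accumulate (out, out)
+            # matrices:
+            #
+            #   import torch
+            #   w = torch.randn(16, 32); g = torch.randn(16, 32)
+            #   this_p = w.transpose(-1, 0).reshape(-1, 16)     # (32, 16)
+            #   this_g = g.transpose(-1, 0).reshape(-1, 16)     # (32, 16)
+            #   print(torch.matmul(this_g.t(), this_p).shape)   # torch.Size([16, 16])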

-        # moving_d contains a moving average of parameter change contributions
-        # from multiple iterations with weights: (1-beta) * (1 + beta + beta^2 +
-        # beta^3, etc.), which sum to 1 for large step.  Suppose the rms value
-        # of each individual gradient were r.  Assuming the terms are
-        # independent, with rms=r, the variance of the elements of moving_d
-        # would be r^2 * (1-beta)^2 * (1 + beta^2 + beta^4 + ... ) which equals
-        # r^2 * (1-beta)^2 / (1-beta^2), so the measured rms of moving_d would
-        # be moving_d_rms = individual_d_rms * (1-beta) / sqrt(1-beta^2), so
-        # individual_d_rms = moving_d_rms * sqrt(1-beta^2) / (1-beta) ... this
-        # would be the root-mean-square value of the individual deltas, if we
-        # were to proces them individually (if they were independent), and this
-        # is what we want to normalize by for greater independence of the beta
-        # value, i.e. so beta only affects momentum and not the overall speed
-        moving_d_rms = (moving_d**2).mean().sqrt() + eps
+        grad = self._project(grad, forward=True, state=state)

-        # e.g. if beta1==0.98, this formula gives a factor of 9.95, reflecting
-        # that the individual rms's of the un-scaled parameter deltas without
-        # momentum are larger than their moving average.  see the justification
-        # in a comment above.
-        d_rms = moving_d_rms * (((1 - beta1 * beta1) ** 0.5) / (1 - beta1))
-        param_rms = state["param_rms"]
-        alpha = -lr * param_rms / d_rms
+        exp_avg_sq = state["exp_avg_sq"]
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
+                                        value=(1-beta2))

-        # apply max_step_scale to limit magnitude of change on individual
-        # elements, this should help stability.
-        limit = moving_d_rms * max_step_scale
-        moving_d_clamped = moving_d.clamp_(min=-limit, max=limit)
+        step = state["step"]
+        bias_correction2 = 1 - beta2 ** (step + 1)
+        if bias_correction2 < 0.99:
+            # note: not in-place.
+            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

+        denom = exp_avg_sq.sqrt() + eps

-        if random.random() < 0.001:
-            clamp_diff_rms = ((moving_d_clamped - moving_d) ** 2).mean().sqrt()
-            rel_clamp_diff = clamp_diff_rms / d_rms
-            logging.info(f"Clamping difference is {clamp_diff_rms.item():.3g} vs. {d_rms.item():.3g}: "
-                         f"relative diff {rel_clamp_diff.item():.3g}, shape={tuple(p.shape)}")
+        grad = grad / denom

-        p.add_(moving_d_clamped, alpha=alpha)
+        # and project back.
+        grad = self._project(grad, forward=False, state=state)

+        scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
+        scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
+                                           alpha=1-beta2)

+        if bias_correction2 < 0.99:
+            # note: not in-place.
+            scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)

+        denom = scalar_exp_avg_sq.sqrt() + eps  # scalar.

+        # negated so that we descend.  Note: alpha is a 0-dim tensor (because
+        # param_rms and denom are tensors), so below we multiply by it rather
+        # than passing it as the `alpha=` kwarg of add_(), which needs a number.
+        alpha = -state["param_rms"] * (1-beta1) * lr / denom

+        delta = state["delta"]
+        delta.mul_(beta1).add_(grad * alpha)

+        p.add_(delta)


-    def _step_small(self,
+    def _step_scalar(self,
                      beta1: float,
                      beta2: float,
                      eps: float,
                      lr: float,
                      p: Tensor,
                      grad: Tensor,
                      state: dict):
         """
-        A form of the core update for tensors with a small number of elements, e.g. <5.  This is
+        A form of the core update for tensors with a small number of elements, e.g. scalars.  This is
         Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
         """
-        exp_avg = state["exp_avg"]
-        exp_avg_sq = state["exp_avg_sq"]
+        # we use "delta" for momentum rather than a separate exp_avg, consistent
+        # with _step() above ("exp_avg" is no longer created in the state).
+        delta = state["delta"]
+        exp_avg_sq = state["scalar_exp_avg_sq"]

-        exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

@@ -505,11 +652,40 @@
         bias_correction2 = 1 - beta2 ** (state["step"] + 1)
         denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
-        alpha = -lr / denom
-        if p.numel() > 1:
-            alpha = alpha * state["param_rms"]
+        alpha = -lr * (1-beta1) / denom
+        # for now not supporting p.numel() > 1.
+        # if p.numel() > 1:
+        #     alpha = alpha * state["param_rms"]

-        p.add_(exp_avg, alpha=alpha)
+        # alpha is a 0-dim tensor, so multiply rather than using the `alpha=` kwarg.
+        delta.mul_(beta1).add_(grad * alpha)
+        p.add_(delta)
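+    # Editor's sketch (illustration only; simplified, not from the patch): with
+    # identity projections, the first step taken by _step() above reduces to
+    # sign-like gradient descent whose rms step size is lr * (1-beta1) * param_rms:
+    #
+    #   import torch
+    #   torch.manual_seed(0)
+    #   lr, beta1, eps = 0.03, 0.9, 1e-8
+    #   p = torch.randn(8, 4); grad = torch.randn(8, 4)
+    #   g = grad / (grad.abs() + eps)          # first-step Adam normalization
+    #   scalar_rms = (g**2).mean().sqrt()      # ~= 1.0 after normalization
+    #   param_rms = (p**2).mean().sqrt()
+    #   alpha = -param_rms * (1 - beta1) * lr / (scalar_rms + eps)
+    #   delta = g * alpha
+    #   print((delta**2).mean().sqrt() / param_rms)   # ~= lr * (1-beta1) = 0.003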

+    def _project(self,
+                 x: Tensor,
+                 forward: bool,
+                 state: dict) -> Tensor:
+        """
+        Multiply a tensor x by proj_{dim} on each of its dimensions.
+        If forward == True, converts x from canonical to preconditioned/diagonalized
+        co-ordinates; if forward == False, does the reverse.
+
+        Args:
+           x: The tensor to project, e.g. a gradient or a parameter change.
+           forward: if True, go in the forward direction (from canonical to
+                diagonalized co-ordinates); if False, the reverse.  The two
+                directions differ by a transpose, not an inverse, so they do not
+                make a round trip.
+        """
+        for dim in range(x.ndim):
+            if x.shape[dim] == 1:
+                continue
+            proj = state[f"proj_{dim}"]
+            if forward:
+                # dimension 0 of `proj` is the diagonalized co-ordinate.
+                proj = proj.t()
+            # TODO: could possibly somehow force the output format to be unchanged.
+            x = x.transpose(-1, dim)
+            x = torch.matmul(x, proj)
+            x = x.transpose(-1, dim)
+        return x

     def _store_grad_stats(self,
                           grad: Tensor,
                           state: dict,