From 6810849058ddc5ba2b92304445679b06859c2efd Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 9 Jul 2022 10:02:17 +0800
Subject: [PATCH] Implement new version of learning method. Does more complete
 diagonalization of grads than the previous methods.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 617 ++++++------------
 1 file changed, 196 insertions(+), 421 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index e624d53e7..fa951d94c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -24,10 +24,10 @@ from torch.optim import Optimizer
 from icefall import diagnostics # only for testing code
 import logging
 
-class LearnedGradient(Optimizer):
+class PrAdam(Optimizer):
     """
-    Implements 'Learned Gradient' update. It's special factorization of the
-    parameter matrix..
+    Implements 'Projected Adam', a variant of Adam with projections for each
+    dim of each tensor.
 
     Args:
@@ -35,63 +35,70 @@ class LearnedGradient(Optimizer):
          lr:  The learning rate.  We will typically use a learning rate schedule that starts
               at 0.03 and decreases over time, i.e. much higher than other common
               optimizers.
+      betas:  beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
+              scalar tensors.  Must satisfy 0 < beta1 <= beta2 < 1.
  size_lr_scale: A scaling factor on the learning rate, that we use to update the
               scale of each parameter tensor.  If each parameter were decomposed
               as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
               is the scaling factor on the learning rate of p_scale.
-      meta_lr_scale: The learning rate for the learning-rate matrix, this is a scale on
-              the regular learning rate.
-      betas:  beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
-              scalar tensors.  Must satisfy 0 < beta <= beta2 < 1.
+      param_pow: Power on the parameter covariance matrix; 1.0 means learn proportional
+              to parameter rms (1.0 will be too much; it should be between 0 and 1).
+  param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
+              assumed rank of param covariance [==product of sizes on the other
+              tensor dims] approaches 0.
+  param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
+              param covariance equals the dimension of the covariance matrix.
+              param_rms_smooth{0,1} determine the smoothing proportions for other
+              conditions.
+      param_cov_freshness: Constant that determines how "recent" the parameter covariance
+              matrix is.  1.0 corresponds to a flat average over time; 2.0
+              would mean weighting with a quadratic function, etc.
        eps:  An epsilon to prevent division by zero
  param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
               learning the scale on the parameters (we'll keep it >= this size)
  param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
               learning the scale on the parameters (we'll keep it <= this size)
-      lr_mat_min: Minimum singular value of learning-rate matrices Q.
-      lr_mat_max: Maximum singular value of learning-rate matrices Q.
-      lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
-      diagonalize_period: How often we re-diagonalize the gradient covariance; this
-             gets multiplied by lr_est_period, i.e. we diagonalize less frequently
-             than we update the learning-rate matrixes.
+      size_update_period: The periodicity, in steps, with which we update the size (scale)
+              of the parameter tensor.  This is provided to save a little time.
+      lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
     """
     def __init__(
             self,
             params,
             lr=3e-02,
-            size_lr_scale=0.1,
-            meta_lr_scale=0.0,
             betas=(0.9, 0.98),
+            size_lr_scale=0.1,
+            param_pow=0.75,
+            param_rms_smooth0=0.75,
+            param_rms_smooth1=0.25,
+            param_cov_freshness=1.0,
             eps=1.0e-08,
-            size_update_period=1,
             param_min_rms=1.0e-05,
             param_max_rms=2.0,
-            lr_mat_min=0.01,
-            lr_mat_max=10.0,
-            lr_est_period=2,
-            diagonalize_period=4,
+            size_update_period=4,
+            lr_update_period=20,
     ):
 
         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
-            meta_lr_scale=meta_lr_scale,
+            param_pow=param_pow,
+            param_rms_smooth0=param_rms_smooth0,
+            param_rms_smooth1=param_rms_smooth1,
+            param_cov_freshness=param_cov_freshness,
             betas=betas,
             eps=eps,
-            size_update_period=size_update_period,
             param_min_rms=param_min_rms,
             param_max_rms=param_max_rms,
-            lr_mat_min=lr_mat_min,
-            lr_mat_max=lr_mat_max,
-            lr_est_period=lr_est_period,
-            diagonalize_period=diagonalize_period,
+            size_update_period=size_update_period,
+            lr_update_period=lr_update_period,
         )
 
-        super(LearnedGradient, self).__init__(params, defaults)
+        super(PrAdam, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(LearnedGradient, self).__setstate__(state)
+        super(PrAdam, self).__setstate__(state)
 
     @torch.no_grad()
@@ -111,14 +118,11 @@ class LearnedGradient(Optimizer):
             lr = group["lr"]
             size_lr = lr * group["size_lr_scale"]
             beta1, beta2 = group["betas"]
-            lr_mat_min = group["lr_mat_min"]
-            lr_mat_max = group["lr_mat_max"]
             eps = group["eps"]
             size_update_period = group["size_update_period"]
             param_min_rms = group["param_min_rms"]
             param_max_rms = group["param_max_rms"]
-            lr_est_period = group["lr_est_period"]
-            diagonalize_period = group["diagonalize_period"]
+            lr_update_period = group["lr_update_period"]
 
             for p in group["params"]:
                 if p.grad is None:
@@ -128,7 +132,7 @@
                 grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError(
-                        "LearnedGradient optimizer does not support sparse gradients"
+                        "PrAdam optimizer does not support sparse gradients"
                     )
                 state = self.state[p]
 
@@ -177,29 +181,14 @@ class LearnedGradient(Optimizer):
                         continue
 
-                    # Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this
-                    # dimension, times an orthogonal matrix that we'll periodically
-                    # re-estimate when we "rediagonalize" (this diagonalizes
-                    # the gradient covariance).
-                    # If our parameter matrix M is of shape (..., size),
-                    # we'll view M as being M == torch.matmul(N, Q)
-                    # so p can be interpreted as having shape
-                    # (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
-                    # by "diagonalize" we mean that we diagonalize the gradient covariance.
+                    # Q_{dim} is the learning-rate matrix for this
+                    # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
+                    # It is used twice in the update step, once transposed and once without transpose.
                     state[f"Q_{dim}"] = torch.eye(size, **kwargs)
 
-                    # proj_grad_{dim} is the gradient w.r.t. `proj`,
-                    # which is a matrix we introduce to the right of Q, i.e.
-                    # instead of M == torch.matmul(N, Q) we view it as
-                    # M == torch.matmul(torch.matmul(N, torch.matmul(Q, proj)))
-                    # or equivalently M == torch.matmul(M_underlying, proj)
-                    # where M_underlying is numerically equal to M, and proj is numerically
-                    # equal to I, but they are viewed as separate variables.
-                    #
-                    # proj_grad_{dim} will be set to the sum of
-                    # torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps),
-                    # and will be zeroed after we call self_update_lrs().
-                    state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
+                    # param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension, treating
+                    # all other dims as a batch axis.
+                    state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
 
                     # grad_cov_{dim} is the covariance of gradients on this axis (without
                     # any co-ordinate changes), treating all other axes as as a batch axis.
@@ -213,14 +202,6 @@
                     # instead of just using a temporary and smoothing the scalar factor.
                     state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
 
-                    # proj_grad_var_{dim} is the variance of gradients of the thing we're
-                    # going to multiply proj_{dim} by (between it and the "underlying
-                    # parameter matrix"), aggregated over each of its 2 dimensions
-                    # and decayed over time with beta1 each time we do a step on proj_{dim},
-                    # e.g. every lr_est_period steps.  It's a factored representation
-                    # of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
-                    # storing an exp_avg_sq for proj_{dim}, to save GPU memory.
-                    state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)
 
                 step = state["step"]
 
@@ -250,10 +231,9 @@
                     # parameter rms.  Updates delta.
                     self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                 else:
-                    if step % lr_est_period == 0 and step > 0:
+                    if step % lr_update_period == 0 and step > 0:
+                        self._accum_param_covs(group, p, state)
                         self._update_lrs(group, p, state)
-                    if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
-                        self._diagonalize_lrs(group, p, state)
                         self._zero_exp_avg_sq(state)
                     self._step(group, p, grad, state)
                 p.add_(delta)
@@ -326,31 +306,19 @@
         # the factor of (1-beta1) relates to momentum.
         delta.add_(p, alpha=(1-beta1) * scale_step)
 
-
-    def _update_lrs(self,
-                    group: dict,
-                    p: Tensor,
-                    state: dict) -> None:
+    def _accum_param_covs(self,
+                          group: dict,
+                          p: Tensor,
+                          state: dict) -> None:
         """
-        Updates the learning-rate matrices state["Q_{dim}"].
-
-        Args:
-          group: dict to look up configuration values
-              p: parameter matrix that we are updating.  The learning rate matrices
-                 are actually factors of p, so p itself will change when we change
-                 them.
-          state: state dict for the current parameter
+        Updates state[f"param_cov_{dim}"] for each dim of tensor p.
         """
-        # meta_lr is the learning rate for the learning rate matrix.  Including the factor
-        # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
-        # to the lr_est_period (close to convergence).
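# The per-dimension state described above can be pictured with a small standalone
# sketch (not part of the patch; `init_projection_state` is a hypothetical helper).
# For every dim of size > 1 the optimizer keeps a (size, size) learning-rate matrix
# Q_{dim} (started at identity) plus running parameter and gradient covariances:
import torch

def init_projection_state(p: torch.Tensor, state: dict) -> None:
    kwargs = {"device": p.device, "dtype": p.dtype}
    for dim, size in enumerate(p.shape):
        if size == 1:
            continue
        # learning-rate matrix, indexed [diagonalized_coordinate, canonical_coordinate]
        state[f"Q_{dim}"] = torch.eye(size, **kwargs)
        # averaged-over-time parameter covariance on this dim
        state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
        # decaying-average gradient covariance on this dim
        state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)

# e.g. a (256, 512) weight gets 256x256 state for dim 0 and 512x512 state for dim 1:
state = {}
init_projection_state(torch.randn(256, 512), state)
assert state["Q_0"].shape == (256, 256) and state["param_cov_1"].shape == (512, 512)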
-        meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
-        beta1 = group["betas"][0]
         eps = group["eps"]
-        delta = state["delta"]
-        ndim = p.ndim
-
-        for dim in range(ndim):
+        lr_update_period = group["lr_update_period"]
+        step = state["step"]
+        this_weight = (group["param_cov_freshness"] * lr_update_period /
+                       (step + lr_update_period))
+        for dim in range(p.ndim):
             size = p.shape[dim]
             if size == 1:
                 continue
@@ -359,153 +327,14 @@
             # all other dimensions of the tensor.
             M_full = p.transpose(dim, -1)
             M = M_full.reshape(-1, size)
+            this_param_cov = torch.matmul(M.t(), M)
+            # normalize the scale of this param_cov, in case the parameter scale
+            # changes significantly during training, which would cause some
+            # parts of the training timeline to be more highly weighted.
+            this_param_cov /= (this_param_cov.diag().mean() + eps)
+            state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
+                                                               alpha=this_weight)
 
-            # proj_grad is a summed gradient, of shape (size, size),
-            # indexed [old_param_idx][new_param_idx] (the point is,
-            # param_idx corresponds to the index of the current M, new_param_idx
-            # corresponds to the index of the M after a small change).
-            #
-            # Suppose M is our parameter matrix p reshaped to (-1, size) with
-            # the -1 being all other dims of the tensor treated as a batch dim
-            # and "size" is the size of whatever dimension we are working on
-            # right now.  And suppose M_grad is the loss derivative w.r.t. M
-            # (for compatibility with things like Torch, we use a convention
-            # where all loss derivatives/grads have the same shape as the things they
-            # are derivative with respect to).
-            #
-            # proj_grad is the accumulated value, on the last few steps, of M M_grad^T,
-            # which will be of shape (size, size).  For purposes of our update, if we view
-            # the parameter matrix M as (M_underlying proj) with
-            # proj being numerically equal to I but treated as a variable, i.e. as
-            # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
-            # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
-            # where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad
-            #
-            # Now, proj_grad is convenient to compute, being independent of Q,
-            # but it's not in a very normalized
-            # space; we want to normalize it before doing gradient descent on it.
-            # We're going to view M as
-            #   M == N Q,
-            # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
-            # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
-            # It's likely better numerically to do gradient descent on a `proj2` matrix
-            # that's between N and Q,
-            # i.e.:
-            #   M == N proj2 Q
-            # We want to obtain the derivative w.r.t. proj2.
-            # So the linearized pseudo-loss is
-            #   tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
-            # which we can view as tr(proj2_grad^T proj2), where
-            #   proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
-            #   proj2_grad = Q^{-T} proj_grad Q^T     (eq.1)
-            #
-            # After computing proj2_grad we compute/update a factorized representation of
-            # its variance, factorized over its 2 axes, and do an Adam-type update
-            # for it, except without momentum at this level (momentum will be applied
-            # later).  This gives us proj2_delta, according to the
-            # equation:
-            #   proj2 = I + proj2_delta
-            # (again: proj2_delta is derived from proj2_grad using an Adam-style update).
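# A minimal standalone sketch of the accumulation that _accum_param_covs does for
# one dim (the free function `accum_param_cov` below is hypothetical, for
# illustration only).  All dims other than `dim` are flattened into a batch axis,
# the covariance M^T M is normalized by its mean diagonal so that changes of
# overall parameter scale do not re-weight the timeline, and it is folded into the
# running average with weight
# param_cov_freshness * lr_update_period / (step + lr_update_period):
import torch

def accum_param_cov(p: torch.Tensor, dim: int, param_cov: torch.Tensor,
                    step: int, lr_update_period: int = 20,
                    param_cov_freshness: float = 1.0, eps: float = 1.0e-08) -> None:
    size = p.shape[dim]
    M = p.transpose(dim, -1).reshape(-1, size)   # (batch, size)
    this_cov = torch.matmul(M.t(), M)            # (size, size)
    this_cov = this_cov / (this_cov.diag().mean() + eps)
    w = param_cov_freshness * lr_update_period / (step + lr_update_period)
    param_cov.mul_(1 - w).add_(this_cov, alpha=w)

# With freshness == 1.0 the weights roughly give a flat average over all the
# accumulation steps seen so far (w == 0.5 at the first update below):
p, cov = torch.randn(64, 128), torch.zeros(128, 128)
accum_param_cov(p, dim=1, param_cov=cov, step=20)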
- # - # We now have to update the "learning-rate matrix" Q using proj2, as in - # Q := proj2 Q - # Q := (proj2_delta + I) Q, or: - # Q += proj2_delta Q (eq.2) - # - # We also need to update M itself; this is easiest done by computing proj_delta - # (which is the difference of proj from I). The relationship between proj and proj2 - # is given by equating - # M == N proj2 Q == N Q proj, - # so proj2 Q == Q proj - # proj == Q^{-1} proj2 Q - # and subtracting out the constant "I" part, - # proj_delta == Q^{-1} proj2_delta Q (eq.3) - # - # and then because we view the parameter matrix as - # "M proj", - # with proj being numerically equal to I but learned here, - # it becomes M (I + proj_delta) = M + M proj_delta. - # So the update to M is: - # M += M proj_delta - # or equivalently: - # M_delta = M proj_delta (eq.5) - proj_grad = state[f"proj_grad_{dim}"] - Q = state[f"Q_{dim}"] - - # delta_scale < 1 will make it update the learning rates faster than it otherwise would, - # as we'll reach equilbrium with M less rapidly. - delta_scale=0.5 - - if True: - # Simpler update, do it to the right of Q. - # symmetrize, otherwise there is a bad interaction with the regular update. - proj_grad[:] = proj_grad + proj_grad.t() - proj_grad_var = self._get_matrix_var(proj_grad, group["eps"]) - denom = proj_grad_var.sqrt() - proj_delta = proj_grad / denom - Q_delta = torch.matmul(Q, proj_delta) - M_delta = torch.matmul(M, proj_delta) - M_delta_full = M_delta.reshape(M_full.shape) - this_delta = M_delta_full.transpose(dim, -1) - delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1)) - # there is no momentum on Q. - Q.add_(Q_delta, alpha=-meta_lr) - proj_grad.zero_() - continue - - - # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T - # torch.solve(A, B) returns A^{-1} B. - try: - # not all versions of Torch have this. - Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad) - except: - # torch.solve has args the other way round and also returns LU. - Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t()) - - # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T - proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t()) - - # symmetrize, otherwise there is a bad interaction with the regular - # update in self._step(). - proj2_grad = proj2_grad + proj2_grad.t() - - # for now just estimate proj2_grad_var from proj2_grad; later, we - # may try aggregating it across iterations. This is like exp_avg_sq - # in Adam. - proj2_grad_var = self._get_matrix_var(proj2_grad, group["eps"]) - - denom = proj2_grad_var.sqrt() - - # we'll take into account the factor of -meta_lr * (1-beta1) later on; - # actually proj2_delta is the negative of a parameter change. - proj2_delta = proj2_grad / denom - - # See (eq.2), Q += proj2_delta Q - Q_delta = torch.matmul(proj2_delta, Q) - - # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta - try: - proj_delta = torch.linalg.solve(Q, Q_delta) - except: - # torch.solve has args the other way round and also returns LU. - proj_delta, _ = torch.solve(Q_delta, Q) - - M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta - - M_delta_full = M_delta.reshape(M_full.shape) - - this_delta = M_delta_full.transpose(dim, -1) - # we'l be applying momentum on delta, so multiply by (1-beta1) to - # get the total magnitude of the change right. The 2 changes below - # are the final changes, the only 2 we make in this loop that have - # side effects. - - delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1)) - # there is no momentum on Q. 
- Q.add_(Q_delta, alpha=-meta_lr) - - proj_grad.zero_() def _zero_exp_avg_sq(self, state: dict) -> None: @@ -516,25 +345,12 @@ class LearnedGradient(Optimizer): state["scalar_exp_avg_sq"].zero_() state["zero_step"] = state["step"] - def _diagonalize_lrs(self, - group: dict, - p: Tensor, - state: dict) -> None: + def _update_lrs(self, + group: dict, + p: Tensor, + state: dict) -> None: """ - Diagonalizes the learning-rate matrices state["Q_{dim}"], and also - applies the min- and max- singular value constraints. This Q - is applied to the parameter matrix M such that we interpret - M == N Q. - By "diagonalize", we mean to left-miltiply Q by an orthogonal matrix, - Q := U Q, - so that the covariance of gradients w.r.t. N, measured - over the 2nd index of N while treating its 1st index as a batch index, - is diagonal. - - Before doing this, also limits the singular values of Q so that the - max,min eig are lr_min_eig,lr_max_eig, and then renormalizing so the - mean is 1.0 and limiting again. - + Computes learning-rate matrices Q for each dim of this tensor. Args: group: dict to look up configuration values p: parameter matrix that we are updating. The learning rate matrices @@ -542,158 +358,87 @@ class LearnedGradient(Optimizer): them. state: state dict for the current parameter """ - # meta_lr is the learning rate for the learning rate matrix. Including the factor - # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant - # to the lr_est_period (close to convergence). - - lr_mat_min = group["lr_mat_min"] - lr_mat_max = group["lr_mat_max"] - ndim = p.ndim + numel = p.numel() for dim in range(ndim): size = p.shape[dim] if size == 1: continue + param_cov = state[f"param_cov_{dim}"] + U, S, V = _svd(param_cov) + rank = numel // size + rms = self._smooth_param_rms(group, S.sqrt(), rank) + + if random.random() < 0.05: + logging.info(f"Shape={tuple(p.shape)}, dim={dim}, size={size}, rms={rms[::10]}") + Q = state[f"Q_{dim}"] - Q_orig = Q.clone() # TEMP - for i in range(4): - try: - self._diagonalize_lr(group, p, dim, state) - break # break from the loop over i: nothing was raised, so it succeeded. - except: - logging.warning(f"self._diagonalize_lr raised: i={i}, Q.sum()={Q.sum()}: likely " - f"error in SVD. Will try again after random change.") - # Likely torch SVD failed, its implementation for GPU has bugs for some torch versions, - # e.g. 1.7.1. Left-multiply Q by an arbitrary orthogonal matrix, and we'll try again. - # (Q is going to be diagonalized on that axis anyway, so this shouldn't meaningfully - # affect the output). - U, S, V = torch.randn_like(Q_orig).svd() - # U is a random orthogonal matrix. - Q[:] = torch.matmul(U, Q_orig) + Q[:] = (U * rms).t() + + if True: + # This block does the actual diagonalization. + # Suppose the actual parameter matrix p is M, of shape (-1, size), where + # the -1 represents all other tensor dims treated as a batch dimension. + # M_grad is the same shape as M. We could write a pseudo-loss as + # loss = tr(M_grad^T M) + # Because we can decompose M as M == N Q, we can write: + # loss = tr(M_grad^T N Q) = tr(Q M_grad^T N), + # so we can write this at tr(N_grad^T N), + # where N_grad == (Q M_grad^T)^T = M_grad Q^T. + # Now, + # grad_cov == M_grad^T M_grad, + # decaying-averaged over minibatches; this is of shape (size,size). + # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. 
N + # (which is what we want to diagonalize), as:: + # N_grad_cov = N_grad^T N_grad + # = Q M_grad^T M_grad Q^T + # = Q grad_cov Q^T + # (note: this makes sense because the 1st index of Q is the diagonalized + # index). + + def dispersion(v): + # a ratio that, if the eigenvalues are more dispersed, it will be larger. + return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item() + + grad_cov = state[f"grad_cov_{dim}"] + N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t())) + N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric + U, S, V = _svd(N_grad_cov) + if random.random() < 0.1: + logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion " + f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}") + + + # N_grad_cov is SPD, so + # N_grad_cov = U S U^T. + + # Now, we can diagonalize N_grad_cov with: + # U^T N_grad_cov U == S. + # N_grad_cov is a sum of N_grad^T N_grad. + # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal. + # The linearized pseudo-loss can be written as tr(N_grad^T N_grad). + # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I, + # + # which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted + # as tr(hat_N_grad hat_N), where: + # hat_N_grad = N_grad U + # hat_N = N U + # (hat_N means \hat{N}, or N with a hat on it). + # So if we interpret hat_N = N U, the gradient covariance w.r.t. + # hat_N will be diagonalized. We also modify Q to hat_Q when + # we modify hat_N, to keep the product M unchanged: + # M = N Q = N U U^T Q = hat_N hat_Q + # This can be done by setting + # hat_Q := U^T Q (eq.10) + # + # This is the only thing we have to do, as N is implicit + # and not materialized at this point. + Q[:] = torch.matmul(U.t(), Q) - def _diagonalize_lr(self, - group: dict, - p: Tensor, - dim: int, - state: dict) -> None: - """ - This piece of code was broken out of _diagonalize_lrs so that we can more easily - recover from errors in torch SVD, which can be quite common in some torch versions; - e.g. in torch.1.7.1+cu101, CUDA SVD generates NaNs for some very reasonably-sized - inputs, like I have a 200 by 200 "bad case", whose min and max elements are - -0.2766 and 0.3444 respectively, that generates NaN's in its U and V factors. - Having closely-spaced singular values seems to present a problem. - """ - lr_mat_min = group["lr_mat_min"] - lr_mat_max = group["lr_mat_max"] - - Q = state[f"Q_{dim}"] - if True: - # This block limits the singular values of Q. - U,S,V = Q.svd() - S_new = S.clamp(lr_mat_min, lr_mat_max) - S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0 - S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more. - # Reconstruct Q with the modified S. - Q[:] = torch.matmul(U * S_new, V.t()) - if random.random() < 0.03: - subsample = max(1, S.numel() // 20) - logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}") - - if True: - # This block does the actual diagonalization. - # Suppose the actual parameter matrix p is M, of shape (-1, size), where - # the -1 represents all other tensor dims treated as a batch dimension. - # M_grad is the same shape as M. We could write a pseudo-loss as - # loss = tr(M_grad^T M) - # Because we can decompose M as M == N Q, we can write: - # loss = tr(M_grad^T N Q) = tr(Q M_grad^T N), - # so we can write this at tr(N_grad^T N), - # where N_grad == (Q M_grad^T)^T = M_grad Q^T. 
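# A small numerical check of the re-diagonalization described above (a sketch, not
# code from the patch; `diagonalize_Q` is a hypothetical helper).  Setting
# hat_Q := U^T Q, where U holds the eigenvectors of Q grad_cov Q^T, makes the
# gradient covariance w.r.t. N (where M == N Q) diagonal, as in (eq.10):
import torch

def diagonalize_Q(Q: torch.Tensor, grad_cov: torch.Tensor) -> torch.Tensor:
    N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
    N_grad_cov = N_grad_cov + N_grad_cov.t()   # ensure exactly symmetric
    U, S, V = N_grad_cov.svd()
    return torch.matmul(U.t(), Q)              # hat_Q = U^T Q   (eq.10)

size = 8
Q = torch.randn(size, size)
A = torch.randn(size, size)
grad_cov = torch.matmul(A, A.t())              # an SPD stand-in for grad_cov_{dim}
hat_Q = diagonalize_Q(Q, grad_cov)
new_cov = torch.matmul(hat_Q, torch.matmul(grad_cov, hat_Q.t()))
off_diag = new_cov - torch.diag(new_cov.diag())
assert off_diag.abs().max() < 1.0e-04 * new_cov.diag().max()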
- # Now, - # grad_cov == M_grad^T M_grad, - # decaying-averaged over minibatches; this is of shape (size,size). - # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N - # (which is what we want to diagonalize), as:: - # N_grad_cov = N_grad^T N_grad - # = Q M_grad^T M_grad Q^T - # = Q grad_cov Q^T - # (note: this makes sense because the 1st index of Q is the diagonalized - # index). - - def dispersion(v): - # a ratio that, if the eigenvalues are more dispersed, it will be larger. - return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item() - - grad_cov = state[f"grad_cov_{dim}"] - N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t())) - N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric - U, S, V = N_grad_cov.svd() - if random.random() < 0.1: - logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion " - f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}") - - # N_grad_cov is SPD, so - # N_grad_cov = U S U^T. - - # Now, we can diagonalize N_grad_cov with: - # U^T N_grad_cov U == S. - # N_grad_cov is a sum of N_grad^T N_grad. - # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal. - # The linearized pseudo-loss can be written as tr(N_grad^T N_grad). - # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I, - # - # which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted - # as tr(hat_N_grad hat_N), where: - # hat_N_grad = N_grad U - # hat_N = N U - # (hat_N means \hat{N}, or N with a hat on it). - # So if we interpret hat_N = N U, the gradient covariance w.r.t. - # hat_N will be diagonalized. We also modify Q to hat_Q when - # we modify hat_N, to keep the product M unchanged: - # M = N Q = N U U^T Q = hat_N hat_Q - # This can be done by setting - # hat_Q := U^T Q (eq.10) - # - # This is the only thing we have to do, as N is implicit - # and not materialized at this point. - Q[:] = torch.matmul(U.t(), Q) - - s = Q[:].sum() - if not s == s: # if a NaN was generated: - raise RuntimeError('Likely SVD failure') - - - def _get_matrix_var(self, - x: Tensor, - eps: float) -> Tensor: - """ - Estimate a variance of elements of x, which must be a tensor with exactly - 2 diimensions, neither of which can have size == 1. This is done using - a structured representation, so the variance is a product over the 2 - dimensions (otherwise, it wouldn't be possible to estimate a variance from - a single sample). Eventually we may do this via aggregation across - time as well; this approach seems to give a high variance estimate even - for largish dimension, e.g. if x.shape == (100,100) the 95th percentile - is from about 0.6 to 1.6. - - Args: - x: tensor to estimate the variance of elements of. Require x.ndim == 2, - neither size must be 1. - eps: Epsilon to avoid zero variance: epsilon is dimensionally a rms value. - """ - assert x.ndim == 2 and 1 not in list(x.shape) # TODO: remove? - - var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps - var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps - # estimate var0 one more time. - var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps - return var1 * var0 - def _step(self, @@ -725,12 +470,9 @@ class LearnedGradient(Optimizer): # grad_cov_{dim}, which are for periodically updating the # learning-rate matrices. 
size = grad.shape[dim] - # accumulate some stats for learning the projections - proj_grad = state[f"proj_grad_{dim}"] grad_cov = state[f"grad_cov_{dim}"] this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad - proj_grad.add_(torch.matmul(this_m.t(), this_g)) # could perhaps accumulate grad_cov less frequently; it's only # needed when we rediagonalize which is not that common. grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g)) @@ -823,7 +565,8 @@ class LearnedGradient(Optimizer): # direction we want to change canonical to diagonalized index, so have # to transpose. Q = Q.t() - # TODO: could possibly somehow force the output format to be unchanged. + # TODO: could possibly somehow force the output memory format to be + # unchanged. x = x.transpose(-1, dim) x = torch.matmul(x, Q) x = x.transpose(-1, dim) @@ -1059,29 +802,33 @@ class LearnedGradient(Optimizer): - def _smooth_param_diag_var(self, - param_diag: Tensor, - param_pow: float, - param_rel_eps: float, - param_rel_max: float, - param_reverse_cutoff: float) -> Tensor: + def _smooth_param_rms(self, + group: dict, + rms: Tensor, + rank: int) -> Tensor: """ - Applies the smoothing formula to the eigenvalues of the parameter covariance - tensor. - (Actually when we use this function in _get_param_cov, they won't exactly - be the eigenvalues because we diagonalize relative to the grad_cov, - but they will be parameter covariances in certain directions). - """ - # normalize so mean = 1.0 - param_diag = param_diag / (param_diag.mean() + 1.0e-20) - # use 1/(1/x + 1/y) to apply soft-max of param_diag with param_rel_max. - # use param_rel_eps here to prevent division by zero. - ans = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max) + Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension + of a tensor. This will be used to construct a learning-rate matrix. - # 2nd factor below becomes a 1/param_diag type of function when - # param_diag >> param_reverse_cutoff. - ans = ans * (param_reverse_cutoff / (param_diag + param_reverse_cutoff)) - return ans ** param_pow + Args: + group: dict for configuration values + rms: a tensor with one dimension, representing diagonalized parameter + root-mean-square values. + rank: the assumed rank of the covariance matrix from which rms was derived, + used to decide how much smoothing to apply. This is actually the + minimal rank, if the parameter matrix were stay fixed during training, + but still relevant to know how robust the parameter covariance estimate is. + """ + smooth0 = group["param_rms_smooth0"] + smooth1 = group["param_rms_smooth1"] + eps = group["eps"] + size, = rms.shape + smooth = (smooth0 + + (smooth1 - smooth0) * size / (size + rank)) + mean = rms.mean() + rms += eps + smooth * mean + new_mean = (eps + (smooth + 1) * mean) + return rms / new_mean def _estimate_proj(self, @@ -1241,7 +988,7 @@ class LearnedGradient(Optimizer): param_freqs = param_freqs.tolist() param_periods = param_periods.tolist() - logging.info(f"LearnedGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}") + logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}") print_info = random.random() < 0.01 i = 0 for p in group["params"]: @@ -1256,6 +1003,27 @@ class LearnedGradient(Optimizer): # Do nothing more, as yet. +def _svd(x: Tensor): + # Wrapper for torch svd that catches errors and re-tries. 
+ randU = None + for i in range(4): + try: + U, S, V = x.svd() + s = U.sum() + if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01): + raise RuntimeError(f"svd failed (or random test), sum={s}") + if randU is not None: + U = torch.matmul(randU.t(), U) + return U, S, V # success + except: + logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely " + f"error in SVD. Will try again after random change.") + U, S, V = torch.randn_like(x).svd() + x = torch.matmul(U, x) + if randU is None: + randU = U + else: + randU = torch.matmul(U, randU) @@ -1902,8 +1670,8 @@ def _test_eve_cain(): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: optim = Cain(m.parameters(), lr=0.03) - elif iter == 2: optim = LearnedGradient(m.parameters(), lr=0.03) - elif iter == 3: optim = LearnedGradient(m.parameters(), lr=0.03) + elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03) + elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() @@ -1965,11 +1733,18 @@ def _test_eve_cain(): logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}") - +def _test_svd(): + device = 'cuda' + for i in range(1000): + X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device) + U, S, V = _svd(X) + X2 = torch.matmul(U*S, V.t()) + assert torch.allclose(X2, X, atol=0.001) if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + #_test_svd() _test_eve_cain() #_test_eden()
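# A standalone sketch of the rms smoothing done by _smooth_param_rms above
# (a hypothetical free function; the defaults are taken from the new PrAdam
# arguments).  The smoothing proportion interpolates between param_rms_smooth0 and
# param_rms_smooth1 depending on how the assumed rank of the parameter covariance
# compares with its dimension, and the result is renormalized to have mean 1.0:
import torch

def smooth_param_rms(rms: torch.Tensor, rank: int,
                     smooth0: float = 0.75, smooth1: float = 0.25,
                     eps: float = 1.0e-08) -> torch.Tensor:
    size, = rms.shape
    smooth = smooth0 + (smooth1 - smooth0) * size / (size + rank)
    mean = rms.mean()
    smoothed = rms + eps + smooth * mean      # pull every value toward the mean
    return smoothed / (eps + (smooth + 1) * mean)

rms = torch.rand(100) + 0.1
out = smooth_param_rms(rms, rank=400)
print(float(out.mean()))   # very close to 1.0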