diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 458fbec44..71bc0b463 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -48,9 +48,11 @@ class LearnedGradient(Optimizer):
              learning the scale on the parameters (we'll keep it >= this size)
        param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
              learning the scale on the parameters (we'll keep it <= this size)
+       lr_mat_min: Minimum singular value of learning-rate matrices Q.
+       lr_mat_max: Maximum singular value of learning-rate matrices Q.
        lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
-       orthogonalize_period: How often we re-orthogonalize the gradient covariance; this
-           gets multiplied by lr_est_period, i.e. we orthogonalize less frequently
+       diagonalize_period: How often we re-diagonalize the gradient covariance; this
+           gets multiplied by lr_est_period, i.e. we diagonalize less frequently
            than we update the learning-rate matrices.
        max_step_scale: Value that determines limit on how large steps can be per element,
            intended to prevent instability during early steps before the learning-rate
@@ -64,11 +66,13 @@ class LearnedGradient(Optimizer):
                 meta_lr_scale=0.1,
                 betas=(0.9, 0.98),
                 eps=1.0e-08,
+                size_update_period=4,
                 param_min_rms=1.0e-05,
                 param_max_rms=2.0,
+                lr_mat_min=0.01,
+                lr_mat_max=4.0,
                 lr_est_period=10,
-                max_lr_eig=4.0,
-                orthogonalize_period=4,
+                diagonalize_period=4,
                 max_step_scale=2.0
    ):
@@ -79,11 +83,13 @@ class LearnedGradient(Optimizer):
            meta_lr_scale=meta_lr_scale,
            betas=betas,
            eps=eps,
+            size_update_period=size_update_period,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
+            lr_mat_min=lr_mat_min,
+            lr_mat_max=lr_mat_max,
            lr_est_period=lr_est_period,
-            max_lr_eig=max_lr_eig,
-            orthogonalize_period=orthogonalize_period,
+            diagonalize_period=diagonalize_period,
            max_step_scale=max_step_scale,
        )
@@ -112,12 +118,13 @@ class LearnedGradient(Optimizer):
        size_lr = lr * group["size_lr_scale"]
        beta1, beta2 = group["betas"]
-        max_lr_eig = group["max_lr_eig"]
+        lr_mat_min = group["lr_mat_min"]
+        lr_mat_max = group["lr_mat_max"]
        eps = group["eps"]
-        size_update_period = 4
-        param_min_eps = group["param_min_rms"]
-        param_max_eps = group["param_max_rms"]
+        size_update_period = group["size_update_period"]
+        param_min_rms = group["param_min_rms"]
+        param_max_rms = group["param_max_rms"]
        lr_est_period = group["lr_est_period"]
-        orthogonalize_period = group["orthogonalize_period"]
+        diagonalize_period = group["diagonalize_period"]
        max_step_scale = group["max_step_scale"]

        for p in group["params"]:
@@ -158,39 +165,52 @@ class LearnedGradient(Optimizer):
                # scalar exp_avg_sq.  not the same as scale_exp_avg_sq
                state["exp_avg_sq_scalar"] = torch.zeros((), **kwargs)

+                # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
+                # set zero_step.
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                state["zero_step"] = 0
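+                # (Sketch of the assumed bias-correction semantics after a zeroing
+                # event; zero_step is presumably read back where exp_avg_sq is used:
+                #   steps_since_zero = state["step"] - state["zero_step"]
+                #   bias_correction2 = 1 - beta2 ** (steps_since_zero + 1)
+                # -- names here are illustrative, for intuition only.)
+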
- state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - # proj_{dim} will be the learning-rate matrix for this + # Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this # dimension, times an orthogonal matrix that we'll periodically - # re-estimate when we "reorthogonalize" (this orthogonalizes + # re-estimate when we "rediagonalize" (this diagonalizes # the gradient covariance). - state[f"proj_{dim}"] = torch.eye(size, **kwargs) + # If our parameter matrix M is of shape (..., size), + # we'll view M as being M == torch.matmul(N, q) + # so p can be interpreted as having shape + # (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate] + # by "diagonalize" we mean that we diagonalize the gradient covariance. + state[f"Q_{dim}"] = torch.eye(size, **kwargs) - # proj_grad_{dim} is the gradient w.r.t. proj_{grad} - # (technically w.r.t a matrix to the left of - # proj_{grad}, we'll project it later); we accumulate - # this for `lr_est_period` steps, and then update - # proj_{dim} and zero this matrix + # proj_grad_{dim} is the gradient w.r.t. `proj`, + # which is a matrix we introduce to the right of p, i.e. + # instead of M == torch.matmul(N, p) we view it as + # M == torch.matmul(torch.matmul(N, torch.matmul(p, proj))) + # or equivalently M == torch.matmul(M_underlying, proj) + # where M_underlying is numerically equal to M, and proj is numerically + # equal to I, but they are viewed as separate variables. + # + # proj_grad_{dim} will be set to the sum of + # torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps), + # and will be zeroed after we call self_update_lrs(). state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs) # grad_cov_{dim} is the covariance of gradients on this axis (without # any co-ordinate changes), treating all other axes as as a batch axis. - # This is needed when we re-orthogonalize, and to compute the + # This is needed when we re-diagonalize, and to compute the # grad_rms_{dim}. We store it as a decaying average, decaying with beta2 # only for purposes of computing the scalar factor on the # learning rate; and, as a result of this, also contributes something - # to the gradient w.r.t. f"lr_{dim}", which is one reason + # to the gradient w.r.t. f"Q_{dim}", which is one reason # why we allocate a variable to keep track of its moving average # instead of just using a temporary and smoothing the scalar factor. state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs) @@ -220,17 +240,21 @@ class LearnedGradient(Optimizer): scale_grads, param_rms, beta1, beta2, step, size_lr, param_min_rms, param_max_rms) - + state["delta"].mul_(beta1) if numel == 1: # For parameters with very few elements we just use a form # of Adam with a scale factor to reflect the overall - # parameter rms. + # parameter rms. Updates delta. 
                    self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                else:
+                    if step % lr_est_period == 0:
+                        self._update_lrs(group, p, state)
+                    if step % (lr_est_period * diagonalize_period) == 0:
+                        self._diagonalize_lrs(group, p, state)
+                        self._zero_exp_avg_sq(state)
                    self._step(group, p, grad, state)
-
+                p.add_(state["delta"])
                state["step"] = step + 1

        return loss
@@ -238,6 +262,7 @@ class LearnedGradient(Optimizer):
    def _size_update(self,
                     p: Tensor,
                     state: dict,
+                    scale_grads: Tensor,
                     beta1: float,
                     beta2: float,
                     size_lr: float,
@@ -258,7 +283,6 @@ class LearnedGradient(Optimizer):
            tensor p (for enforcing param_min_rms and param_max_rms)
          beta1: Beta value for decaying gradient stats
          beta2: Beta value for decaying gradient-squared stats
-          step: The step index
          size_lr: The learning rate to apply to `scale`
          param_min_rms: User-specified minimum on the rms value of p, we will constrain it to >= this
          param_max_rms: User-specified maximum on the rms value of p, we will constrain it to <= this
@@ -268,19 +292,14 @@ class LearnedGradient(Optimizer):
        # products (p * grad).sum() for the `size_update_period` most recent steps
        size_update_period, = scale_grads.shape

-        # correct beta1 and beta2 for the size update period: we will have
-        # faster decay at this level.  The same for the learning rate.
-        beta1_corr = beta1 ** size_update_period
+        # correct beta2 for the size update period: we will have
+        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period
-        size_lr_corr = size_lr * size_update_period

        scale_exp_avg_sq = state["scale_exp_avg_sq"]
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(), alpha=1-beta2_corr)
-        scale_exp_avg = state["scale_exp_avg"]
-        scale_exp_avg.mul_(beta1_corr).add_(
-            scale_grads.mean(), value=1-beta1_corr)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
@@ -291,15 +310,17 @@ class LearnedGradient(Optimizer):
        eps = 1.0e-10  # this value is not critical.
        denom = scale_exp_avg_sq.sqrt() + eps

-        scale_step = -size_lr_corr * (bias_correction2 ** 0.5) * scale_exp_avg / denom
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom

        is_too_small = (param_rms < param_min_rms)
        is_too_large = (param_rms > param_max_rms)

-        scale_step.masked_fill_(is_too_small, size_lr_corr)
-        scale_step.masked_fill_(is_too_large, -size_lr_corr)
+        scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
+        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)

-        p.add_(p, alpha=scale_step)
+        delta = state["delta"]
+        # the factor of (1-beta1) relates to momentum.
+        delta.add_(p * scale_step, alpha=(1-beta1))

    def _update_lrs(self,
@@ -307,7 +328,7 @@ class LearnedGradient(Optimizer):
                    p: Tensor,
                    state: dict) -> None:
        """
-        Updates the learning-rate matrices state["lr_{dim}"].
+        Updates the learning-rate matrices state["Q_{dim}"].

        Args:
          group: dict to look up configuration values
@@ -322,7 +343,7 @@ class LearnedGradient(Optimizer):
        meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
        beta1, beta2 = group["betas"]
        eps = group["eps"]
-        moving_g = state["exp_avg"]
+        delta = state["delta"]
        ndim = p.ndim

        for dim in range(ndim):
@@ -330,148 +351,247 @@ class LearnedGradient(Optimizer):
            if size == 1:
                continue

-            p_t = p.transpose(dim, -1)
-            this_p = p_t.reshape(-1, size)
+            # M is the parameter matrix, of shape (-1, size) where the -1 covers
+            # all other dimensions of the tensor.
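+            # For example (hypothetical shapes): if p has shape (512, 256, 3)
+            # and dim == 1, then:
+            #   M_full = p.transpose(1, -1)    # shape (512, 3, 256)
+            #   M = M_full.reshape(-1, 256)    # shape (1536, 256)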
+            M_full = p.transpose(dim, -1)
+            M = M_full.reshape(-1, size)

-            # proj_grad is a summed gradient, indexed [grad_idx][param_idx] (don't worry
-            # if this notation about the indexing is unclear, just keep reading).
+            # proj_grad is a summed gradient, of shape (size, size),
+            # indexed [old_param_idx][new_param_idx] (the point is,
+            # old_param_idx corresponds to the index of the current M, new_param_idx
+            # corresponds to the index of the M after a small change).
            #
-            # Suppose M is our parameter matrix p reshaped to (size, -1) with the -1 being
-            # all other dims treated as a batch dim, and M_grad is the loss derivative w.r.t. M,
+            # Suppose M is our parameter matrix p reshaped to (-1, size) with
+            # the -1 being all other dims of the tensor treated as a batch dim
+            # and "size" is the size of whatever dimension we are working on
+            # right now.  And suppose M_grad is the loss derivative w.r.t. M
+            # (for compatibility with things like Torch, we use a convention
+            # where all loss derivatives/grads have the same shape as the things they
+            # are derivative with respect to).
            #
-            # Then proj_grad is the accumulated value of M_grad M^T, which will be of shape
-            # (size, size).  If we view the parameter matrix as actually (proj M) with
-            # proj being numerically equal to I, i.e. as
-            # matmul(proj, M), then instead of the loss being tr(M_grad^T M), we can write
-            # it as tr(M_grad^T proj M), which can also be written as tr(M M_grad^T proj)
-            # = tr( (M_grad M^T)^T proj), where we can write M_grad M^T == proj_grad.
+            # proj_grad is the accumulated value, over the last few steps, of M^T M_grad,
+            # which will be of shape (size, size).  For purposes of our update, if we view
+            # the parameter matrix M as (M_underlying proj) with
+            # proj being numerically equal to I but treated as a variable, i.e. as
+            # matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
+            # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
+            # where proj_grad == M_underlying^T M_grad.
            #
            # Now, proj_grad is convenient to compute but it's not in a very normalized
-            # space, we want to normalize it first.  We're going to view M as
-            #   M = T N,
-            # where T is our proj_{dim} matrix (see the code), and N == T^{-1} M.
-            # It's more convenient to have the projection between T and N, i.e. have proj2,
-            # again numerically equal to I, so
-            #  M == T proj2 N.
-            # [Be careful because when we use T as a superscript it means transpose).
-            # We want the derivative w.r.t. proj2.
-            # So suppose the loss is
-            #   tr(M_grad^T M) == tr(M_grad^T T proj2 N) == tr(N M_grad^T T proj2)
-            #                  == tr( (T^T M_grad N^T)^T proj2 )
-            #                  == tr( (T^T M_grad N^T)^T proj2 )
-            # and using N == T^{-1} M, so N^T = M^T T^{-T},
-            #                  == tr( (T^T M_grad M^T T^{-T})^T proj2 )
-            # so we can view the thing in parentheses in the above expression as proj2_grad, i.e.
-            #     proj2_grad = T^T M_grad M^T T^{-T}
-            #     proj2_grad = T^T proj_grad T^{-T}
-            #
-            # note: p = state[f"proj_{dim}"] actually corresponds to T^T, since it has a
-            # shape interpretable as (diagonalized_index, canonical_index).
-            #
-            # So proj2_grad = p proj_grad p^{-1}   (eq.1)
+            # space, so we want to normalize it before doing gradient descent on it.
+            # We're going to view M as
+            #   M == N Q,
+            # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
+            # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
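+            # (A quick self-contained check of the pseudo-loss rewrite above,
+            # with hypothetical sizes; at proj == I the two expressions agree,
+            # since tr(M_grad^T M) == (M_grad * M).sum():
+            #   M = torch.randn(20, 4); M_grad = torch.randn(20, 4)
+            #   proj_grad = torch.matmul(M.t(), M_grad)
+            #   assert torch.allclose((M_grad * M).sum(), torch.trace(proj_grad)))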
+            # It's more convenient to do gradient descent on a `proj` matrix that's between
+            # N and Q, i.e. define `proj2`, also numerically equal to I but treated as a variable,
+            # and have:
+            #   M == N proj2 Q
+            # We want to obtain the derivative w.r.t. proj2.
+            # So the linearized pseudo-loss is
+            #   tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
+            # which we can view as tr(proj2_grad^T proj2), where
+            #   proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
+            #   proj2_grad = Q^{-T} proj_grad Q^T     (eq.1)
            #
            # After computing proj2_grad we compute/update a factorized representation of
            # its variance, factorized over its 2 axes, and do an Adam-type update
            # for it, except without momentum at this level (momentum will be applied
            # later).  This gives us proj2_delta, according to the
-            # equation: proj2 = I + proj2_delta.
+            # equation:
+            #    proj2 = I + proj2_delta
+            # (again: proj2_delta is derived from proj2_grad using an Adam-style update).
            #
-            # This is the point at which we control the magnitude of the parameter change.
-            #
-            # We now have to update the factor T using proj2, as in:
-            #  T <-- T (I + proj2_delta)
-            #  T <-- T + T proj2_delta
-            # and since p == T^T, this becomes:
-            #  p <-- p + proj2_delta^T p
-            # or equivalently:
-            #  p_delta = proj2_delta^T p  (eq.2)
+            # We now have to update the "learning-rate matrix" Q using proj2, as in:
+            #    Q := proj2 Q
+            #    Q := (proj2_delta + I) Q, or:
+            #    Q += proj2_delta Q    (eq.2)
            #
            # We also need to update M itself; this is easiest done by computing proj_delta
            # (which is the deviation of proj from I).  The relationship between proj and proj2
            # is given by equating
-            #     proj T N = T proj2 N,
-            #     proj T = T proj2,
-            #     proj = T proj2 T^{-1}
-            #     proj = p^T proj2 p^{-T} .
-            # we can apply the above formula directly to proj2_delta, since the "I" part just
-            # becomes a no-op, i.e.
-            #   proj_delta = p^T proj2_delta p^{-T}.
-            # ... and to turn this back into a delta on M, we'll do:
-            #   M_delta = proj_delta M.
-
-
+            #    M == N proj2 Q == N Q proj,
+            # so proj2 Q == Q proj
+            #    proj == Q^{-1} proj2 Q
+            # and subtracting out the constant "I" part,
+            #    proj_delta == Q^{-1} proj2_delta Q   (eq.3)
+            # ... so the change in Q is given by:
+            #    Q := Q (I + proj_delta),
+            #    Q += Q proj_delta
+            #    Q_delta = Q proj_delta
+            # and substituting (eq.3) for proj_delta, Q Q^{-1} proj2_delta Q simplifies, giving:
+            #    Q_delta = proj2_delta Q   (eq.4)
+            #
+            # and then because we view the parameter matrix as
+            #  "M_underlying proj",
+            # with proj being numerically equal to I but learned here, the updated matrix
+            # becomes M (I + proj_delta) = M + M proj_delta.
+            # So the update to M is:
+            #    M += M proj_delta
+            # or equivalently:
+            #    M_delta = M proj_delta   (eq.5)

            proj_grad = state[f"proj_grad_{dim}"]
-            p = state[f"proj_{dim}"]
+            Q = state[f"Q_{dim}"]

-            # torch.linalg.solve(A, B) returns X such that AX=B,
-            # meaning X = A^{-1} B.  We want to compute X = proj_grad p^{-1},
-            # so X^T = p^{-T} proj_grad^T.
-            X = torch.linalg.solve(p.t(), proj_grad.t()).t()
-            assert torch.allclose(proj_grad, torch.matmul(X, p))  # TEMP.
+            # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T.
+            # torch.linalg.solve(A, B) returns A^{-1} B.
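+            # (Tiny sanity check of that identity, for illustration only,
+            # assuming a well-conditioned A:
+            #   A = 2.0 * torch.eye(3); B = torch.ones(3, 3)
+            #   X = torch.linalg.solve(A, B)
+            #   assert torch.allclose(torch.matmul(A, X), B)
+            # so below, Q_invt_proj_grad == Q^{-T} proj_grad.)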
+            Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
+            assert torch.allclose(proj_grad, torch.matmul(Q.t(), Q_invt_proj_grad))  # TODO: remove
-            # See (eq.1), the next line implements proj_grad = p proj_grad p^{-1}
-            proj2_grad = torch.matmul(p, X)
+            # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
+            proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())

-            # for now just estimate it from proj2_grad, later we will try aggregating it
-            # across iterations.  This is like exp_avg_sq in Adam.
+            # for now just estimate proj2_grad_var from proj2_grad; later, we
+            # may try aggregating it across iterations.  This is like exp_avg_sq
+            # in Adam.
-            proj2_grad_var = self._get_matrix_var(proj_grad, group["eps"])
+            proj2_grad_var = self._get_matrix_var(proj2_grad, group["eps"])
            denom = proj2_grad_var.sqrt()
-            alpha = -meta_lr * (1-beta1)
-
-            # we'll take into account alpha later on.
-            proj2_delta = proj2_grad_var / denom
+            # we'll take into account the factor of -meta_lr * (1-beta1) later on;
+            # actually proj2_delta is the negative of a parameter change.
+            proj2_delta = proj2_grad / denom

-            p_delta = torch.matmul(proj2_delta.t(), p)
+            # See (eq.4), Q_delta = proj2_delta Q
+            Q_delta = torch.matmul(proj2_delta, Q)
+            # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
+            proj_delta = torch.linalg.solve(Q, Q_delta)
+            M_delta = torch.matmul(M, proj_delta)  # (eq.5), M_delta = M proj_delta

-            alpha = -meta_lr * (1-beta1)
-            # To update the learning rate matrices, we are asking the question, "What if,
-            # on the previous step, we had used a slightly different learning-rate matrix?
-            # How would that have affected the loss function on the current step?"
-            moving_d = self._multiply_by_lr(moving_g)
+            M_delta_full = M_delta.reshape(M_full.shape)

-            # moving_d_rms is an rms value of moving_d computed via inner
-            # products with the grad, which is a basis-independent way of
-            # computing the rms value.  you have to view this as the rms times
-            # an arbitrary multiplicative factor.  in the "real" update we don't
-            # actually do the length normalization like this because it's slow,
-            # we do it a different, simpler way; but here we do it this way
-            # because we are computing gradients and it's going to make it more
-            # sensitive to the norm used.  If we used the regular norm on moving_d, it
-            # might concentrate all the movement in "don't-care" dimensions in
-            # parameter space, assuming such dimensions exist.
-            moving_d_rms = (self._multiply_by_grad_cov(moving_d) * moving_d).mean().sqrt()
+            this_delta = M_delta_full.transpose(dim, -1)
+            # we'll be applying momentum on delta, so multiply by (1-beta1) to
+            # get the total magnitude of the change right.  The 2 changes below
+            # are the final changes, the only 2 we make in this loop that have
+            # side effects.
+            delta.add_(this_delta, alpha=-meta_lr*(1-beta1))
+            # there is no momentum on Q.
+            Q.add_(Q_delta, alpha=-meta_lr)
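+            # (Consistency sketch, with hypothetical sizes: for any small matrix P,
+            # updating Q := (I + P) Q while updating M := M (I + Q^{-1} P Q)
+            # leaves N == M Q^{-1} unchanged, since
+            #   M (I + Q^{-1} P Q) Q^{-1} (I + P)^{-1}
+            #     == M Q^{-1} (I + P) (I + P)^{-1} == M Q^{-1}.
+            # Here this holds only approximately, because the change to M goes
+            # through the momentum buffer `delta` while Q is updated directly.)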
-            # moving_d_norm is the parameter "delta" moving_d, length-normaized but still with
-            # an arbitrary multiplicative constant.  We don't care about this constant
-            # because we are going to normalize the grad length when we do the update;
-            # but it's important to have normalized by dividing by moving_d_rms because it
-            # makes the gradient "scale-independent" in the sense that it is zero in
-            # the direction of increasing or decreasing the parameter scale.
-            moving_d_norm = moving_d / moving_d_rms
-
-            # view the current parameter p as [something] - alpha *
-            # moving_d_norm, where alpha contains the learning rate and other
-            # factors, but we don't care about the numerical value of alpha
-            # because we're going to divide by the gradient norm anyway.  The
-            # (pseudo-) loss function would be (p * grad).sum(), which ignoring
-            # [something] as it contributes no gradient, becomes -alpha *
-            # (moving_d_norm * grad).sum(), and ignoring the scale alpha, that
-            # is -(moving_d_norm * grad).sum(), so we call this neg_loss as it's the
-            # negative of the loss.
-            neg_loss = (grad * moving_d_norm).sum()
-
-            # after the following, there will grads in state[f"lr_{dim}"]
-            # that are the negative of loss function grads, i.e. we want to go
-            # in the forward grad direction.
-            neg_loss.backward()

+    def _zero_exp_avg_sq(self,
+                         state: dict) -> None:
+        """
+        Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
+        """
+        state["exp_avg_sq"].zero_()
+        state["zero_step"] = state["step"]

+    def _diagonalize_lrs(self,
+                         group: dict,
+                         p: Tensor,
+                         state: dict) -> None:
+        """
+        Diagonalizes the learning-rate matrices state["Q_{dim}"], and also
+        applies the min- and max- singular value constraints.  This Q
+        is applied to the parameter matrix M such that we interpret
+        M == N Q.
+        By "diagonalize", we mean to left-multiply Q by an orthogonal matrix,
+           Q := U Q,
+        so that the covariance of gradients w.r.t. N, measured
+        over the 2nd index of N while treating its 1st index as a batch index,
+        is diagonal.
+
+        Before doing this, we also limit the singular values of Q to the range
+        [lr_mat_min, lr_mat_max], renormalize so that their mean is 1.0, and
+        then apply the limits once more.
+
+        Args:
+          group: dict to look up configuration values
+          p: parameter matrix that we are updating.  The learning rate matrices
+             are actually factors of p, so p itself will change when we change
+             them.
+          state: state dict for the current parameter
+        """
+        lr_mat_min = group["lr_mat_min"]
+        lr_mat_max = group["lr_mat_max"]
+
+        ndim = p.ndim
+
+        for dim in range(ndim):
+            size = p.shape[dim]
+            if size == 1:
+                continue
+            Q = state[f"Q_{dim}"]
+
+            if True:
+                # This block limits the singular values of Q.
+                try:
+                    U, S, V = Q.svd()
+                except Exception:
+                    # if SVD fails for some reason we can rotate Q by something arbitrary on the
+                    # left, as we're going to rotate here anyway, and try again.
+                    logging.warning("Error doing SVD on Q.  Trying again after rotation.")
+                    U, _, _ = torch.randn_like(Q).svd()  # Create random orthogonal matrix.
+                    Q[:] = torch.matmul(U, Q)
+                    U, S, V = Q.svd()
+                S_new = S.clamp(lr_mat_min, lr_mat_max)
+                S_new *= 1.0 / S_new.mean()  # normalize so mean is 1.0
+                S_new.clamp_(lr_mat_min, lr_mat_max)  # apply limits once more.
+                # Reconstruct Q with the modified singular values.
+                Q[:] = torch.matmul(U * S_new, V.t())
+                if random.random() < 0.1:
+                    subsample = max(1, S.numel() // 20)
+                    logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")
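+                # (Illustration of the clamping scheme with hypothetical values,
+                # using the default limits lr_mat_min=0.01, lr_mat_max=4.0:
+                #   S = torch.tensor([0.001, 1.0, 10.0])
+                #   S.clamp(0.01, 4.0)     # -> [0.01, 1.0, 4.0]
+                # then divide by the mean (5.01 / 3) and clamp once more.)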
+            if True:
+                # This block does the actual diagonalization.
+                # Suppose the actual parameter matrix p is M, of shape (-1, size), where
+                # the -1 represents all other tensor dims treated as a batch dimension.
+                # M_grad is the same shape as M.  We could write a pseudo-loss as
+                #   loss = tr(M_grad^T M)
+                # Because we can decompose M as M == N Q, we can write:
+                #   loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
+                # so we can write this as tr(N_grad^T N),
+                # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
+                # Now,
+                #     grad_cov == M_grad^T M_grad,
+                # decaying-averaged over minibatches; this is of shape (size, size).
+                # Using N_grad = M_grad Q^T, we can write:
+                #    N_grad_cov = Q M_grad^T M_grad Q^T
+                #               = Q grad_cov Q^T
+                # (note: this makes sense because the 1st index of Q is the diagonalized
+                # index).

+                grad_cov = state[f"grad_cov_{dim}"]
+                N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
+                N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric (the scale does not matter)
+                U, S, V = N_grad_cov.svd()

+                # N_grad_cov is symmetric positive semi-definite, so
+                #    N_grad_cov = U S U^T.
+                # Now, we can diagonalize N_grad_cov with:
+                #      U^T N_grad_cov U == S.
+                # N_grad_cov is a sum of terms of the form N_grad^T N_grad,
+                # so U^T N_grad^T N_grad U is diagonal.
+                # The linearized pseudo-loss can be written as tr(N_grad^T N),
+                # which equals tr(U U^T N_grad^T N) since U U^T == I,
+                # which we can rearrange as tr(U^T N_grad^T N U).  This can be interpreted
+                # as tr(hat_N_grad^T hat_N), where:
+                #   hat_N_grad = N_grad U
+                #   hat_N = N U
+                # (hat_N means \hat{N}, or N with a hat on it).
+                # So if we interpret hat_N = N U, the gradient covariance w.r.t.
+                # hat_N will be diagonalized.  We also modify Q to hat_Q when
+                # we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
+                # unchanged.  This can be done by setting
+                #   hat_Q = U^T Q   (eq.10)
+                #
+                # This is the only thing we have to do, as N is implicit
+                # and not materialized at this point.
+                Q[:] = torch.matmul(U.t(), Q)

-            for dim in range(ndim):
-                if p.shape[dim] != 1:
-                    self._update_lr(state[f"lr_{dim}"], meta_lr)

    def _get_matrix_var(self,
                        x: Tensor,
@@ -581,14 +701,18 @@ class LearnedGradient(Optimizer):
        for dim in range(grad.ndim):
+            # This block accumulates the statistics proj_grad_{dim} and
+            # grad_cov_{dim}, which are for periodically updating the
+            # learning-rate matrices.
-            size = grad.shape[size]
+            size = grad.shape[dim]
            # accumulate some stats for learning the projections
            proj_grad = state[f"proj_grad_{dim}"]
            grad_cov = state[f"grad_cov_{dim}"]
-            this_p = p.transpose(-1, dim).reshape(-1, size)
-            this_g = g.transpose(-1, dim).reshape(-1, size)
-            proj_grad.add_(torch.matmul(this_g.t(), this_p))
-            # could perhaps accumulate grad_cov less frequently.
+            this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
+            this_g = g.transpose(-1, dim).reshape(-1, size)  # M_grad
+            proj_grad.add_(torch.matmul(this_m.t(), this_g))
+            # could perhaps accumulate grad_cov less frequently; it's only
+            # needed when we rediagonalize, which is not that common.
            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

-        grad = self._project(grad, forward=True, state)
+        grad = self._project(grad, state=state, forward=True)
@@ -624,26 +748,25 @@ class LearnedGradient(Optimizer):
            alpha = state["param_rms"] * (1-beta1) * lr / denom

            delta = state["delta"]
-            delta.mul_(beta1).add_(grad, alpha=alpha)
-            param.add_(delta)
+            delta.add_(grad, alpha=alpha)
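+            # (Sketch of the assumed overall flow for one tensor, combining
+            # _step with its caller step(); names are illustrative:
+            #   delta := beta1 * delta    # momentum decay, done in step()
+            #   delta += alpha * grad     # this method; grad is in the projected space
+            #   p += delta                # done in step(), after _step returns)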
    def _step_scalar(self,
-                    beta1: float,
-                    beta2: float,
-                    eps: float,
-                    lr: float,
-                    p: Tensor,
-                    grad: Tensor,
-                    state: dict):
+                     beta1: float,
+                     beta2: float,
+                     eps: float,
+                     lr: float,
+                     p: Tensor,
+                     grad: Tensor,
+                     state: dict):
        """
-        A form of the core update for tensors with a small number of elements, e.g. scalars.  This is
-        Adam where, if the numel() > 1, the learning rate is proportional to the parameter rms value.
+        A form of the core update for tensors with a small number of elements,
+        e.g. scalars.  This is Adam where, if the numel() > 1, the learning rate
+        is proportional to the parameter rms value.
        """
-        exp_avg = state["exp_avg"]
-        exp_avg_sq = state["scalar_exp_avg_sq"]
+        exp_avg_sq = state["exp_avg_sq_scalar"]
-        exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

@@ -651,11 +774,10 @@ class LearnedGradient(Optimizer):
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq.sum() / (exp_avg_sq.numel() * bias_correction2)).sqrt() + eps
-        alpha = -lr / denom
-        # for now not supporting p.numel() > 1.
-        #if p.numel() > 1:
-        #    alpha = alpha * state["param_rms"]
-        p.add_(exp_avg, alpha=alpha)
+        alpha = -lr * (1-beta1)
+
+        delta = state["delta"]
+        delta.add_(grad / denom, alpha=alpha)

    def _project(self,
@@ -670,19 +792,19 @@ class LearnedGradient(Optimizer):
                 x: Tensor,

        Args:
           x: The tensor to project, e.g. a gradient or a parameter change.
           forward: if True, go in the forward direction (from canonical to diagonalized
-               co-ordinates; if False, the reverse.  They differ by a transpose,
+               co-ordinates); if False, the reverse.  They differ by a transpose,
                not an inverse, and do not make a round trip.
        """
        for dim in range(x.ndim):
            if x.shape[dim] == 1:
                continue
-            proj = state[f"proj_{dim}"]
-            if forward:
-                # dimension 0 of `proj` is diagonalized co-ordinate.
-                proj = proj.t()
+            Q = state[f"Q_{dim}"]
+            if forward:
+                # Q is indexed [diagonalized_index, canonical_index]; transposing
+                # it maps canonical to diagonalized co-ordinates.
+                Q = Q.t()
            # TODO: could possibly somehow force the output format to be unchanged.
            x = x.transpose(-1, dim)
-            x = torch.matmul(x, proj)
+            x = torch.matmul(x, Q)
            x = x.transpose(-1, dim)
        return x