# Copyright 2022 Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from lhotse.utils import fix_random_seed
from scaling import ActivationBalancer
from torch import Tensor
from torch.optim import Optimizer


class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.

    Args:
      params: The parameters or param_groups to optimize (like other Optimizer subclasses)
      defaults: A dict of default hyperparameter values, passed through to class Optimizer.
    """
    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    def batched_params(self, param_group):
        """
        This function returns an iterator over tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this fake parameter
        (it will be physically located in the "state" for one of the real
        parameters, the first one that has any particular shape and dtype).

        The idea is, instead of doing:
          for p in group["params"]:
             state = self.state[p]
             ...
        you can do:
          for p, state in self.batched_params(group["params"]):
             ...

        Args:
          param_group: a list of parameters, e.g. the "params" entry of one of
                self.param_groups.
        """
        batches = defaultdict(list)   # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter

        for p in param_group:
            assert p.grad is not None, "All parameters must have a grad when you call optimizer.step()"
            assert not p.grad.is_sparse, "Sparse gradients not supported."
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)

        for p in param_group:
            key = (str(p.dtype), *p.shape)
            batch = batches[key]
            if p is batch[0]:  # if this is the 1st param in the batch...
                # we arbitrarily store the state in the
                # state corresponding to the 1st parameter in the
                # group.  class Optimizer will take care of saving/loading state.
                state = self.state[p]
                p_stacked = torch.stack(batch)
                grad = torch.stack([p.grad for p in batch])
                p_stacked.grad = grad
                yield p_stacked, state  # <-- calling code will do the actual optimization here!
                # Now un-stack the parameter changes.
                for i, p in enumerate(batch):
                    p.copy_(p_stacked[i])


class PrAdam(BatchedOptimizer):
    """
    Implements 'Projected Adam', a variant of Adam with projections for each
    dim of each tensor.

    Args:
         params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
             lr:  The learning rate.  We will typically use a learning rate schedule that starts
                  at 0.03 and decreases over time, i.e. much higher than other common
                  optimizers.
          betas:  (beta1, beta2): momentum constants for regular momentum, and for the moving
                  sum-squared grad (the latter only for scalar tensors).  Must satisfy
                  0 < beta1 <= beta2 < 1.
  size_lr_scale:  A scaling factor on the learning rate that we use to update the
                  scale of each parameter tensor.
If each parameter were decomposed as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale is the scaling factor on the learning rate of p_scale. param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as assumed rank of param covariance [==product of sizes on the other tensor dims] approaches 0. param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of param covariance equals the dimension of the covaraince matrix. param_rms_smooth{0,1} determine the smoothing proportions for other conditions. cov_min,cov_max: [IMPORTANT] 5-tuples of minimums and maximums of the diagonal values of covariance matrices, after normalizing to unit-mean. The first 3 are for smoothing the parameter covariance, normalized in 3 different ways: (1) relative to its own diagonal (in a basis that diagonalizes the grad covariance); (2) multiplied by the grad covariance, (3) in the canonical basis. (4) is for smoothing the grad covariance used for (2) (5) is for smoothing the inverse Z^{-1} final learning-rate matrix Z relative to its own diagonal. Only the cov_min[4] is actually used, we ignore cov_max[4] cov_pow: This was mainly added for development and experimentation purposes; it allows you to smooth the parameter covariance matrices at the stages (1), (2), (3) of smoothing mentioned above, and also the gradient covariance matrix used in stage (2) of smoothing. If the 1st 3 values are not 1.0 it will cause a measurable speed penalty because it requires SVD. Recommend to leave all these at 1.0. eps: A general-purpose epsilon to prevent division by zero denom_rel_eps: An epsilon value used to keep the elements of denominator (exp_avg_grad_sqrt.sqrt()) not too small relative to the mean value; this will limit the speed of update along directions with very small gradients. param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of learning the scale on the parameters (we'll constrain the rms of each non-scalar parameter tensor to be >= this value) param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of learning the scale on the parameters (we'll constrain the rms of each non-scalar parameter tensor to be <= this value) scalar_max: Maximum absolute value for scalar parameters (applicable if your model has any parameters with numel() == 1) size_update_period: The periodicity, in steps, with which we update the size (scale) of the parameter tensor. This is provided to save a little time in the update. lr_update_period: [IMPORTANT]: A 2-tuple of ints that Determines the periodicity, in steps, with which we update the learning-rate matrices. The first number is the periodicity at the start of training, the second number is the periodicity later in training, and we gradually increase from one to the other. The reason for such a complicated schedule is that updating the learning rate matrices is very slow, principally because it requires SVD, and SVD seems to have quite slow implementations. max_block_size: [IMPORTANT] The maximum block size in block-diagonal co-ordinate transformations. You can probably set this to 512 or 1024. Larger values will require more MEMORY and may be a bit slower, but should lead to better optimization performance. 
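
    Example (a minimal usage sketch, mirroring the test code at the bottom of this
    file; the layer sizes, learning rate and scheduler settings are illustrative
    assumptions, not recommendations):

        model = torch.nn.Linear(100, 100)
        optim = PrAdam(model.parameters(), lr=0.03)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5)
        for step in range(100):
            loss = (model(torch.randn(30, 100)) ** 2).mean()
            loss.backward()
            optim.step()
            optim.zero_grad()
            scheduler.step_batch()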
""" def __init__( self, params, lr=3e-02, betas=(0.9, 0.98), size_lr_scale=0.1, cov_min=(0.025, 0.0025, 0.02, 0.0001, 0.05), cov_max=(10.0, 80.0, 5.0, 400.0, 100.0), cov_pow=(1.0, 1.0, 1.0, 1.0), param_rms_smooth0=0.4, param_rms_smooth1=0.2, eps=1.0e-08, denom_rel_eps=1.0e-05, param_min_rms=1.0e-05, param_max_rms=2.0, scalar_max=2.0, size_update_period=4, lr_update_period=(200, 1000), # grad_cov_period does not affect speed very much. If __name__ == # "__main__", keeping this coprime to 20 which is the number of # sample pairs used in the test code grad_cov_period=(3 if __name__ == "__main__" else 5), max_block_size=1024, ): defaults = dict( lr=lr, size_lr_scale=size_lr_scale, cov_min=cov_min, cov_max=cov_max, cov_pow=cov_pow, param_rms_smooth0=param_rms_smooth0, param_rms_smooth1=param_rms_smooth1, betas=betas, eps=eps, denom_rel_eps=denom_rel_eps, param_min_rms=param_min_rms, param_max_rms=param_max_rms, scalar_max=scalar_max, size_update_period=size_update_period, lr_update_period=lr_update_period, grad_cov_period=grad_cov_period, max_block_size=max_block_size, ) super(PrAdam, self).__init__(params, defaults) def __setstate__(self, state): super(PrAdam, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() batch = True for group in self.param_groups: for p, state in self.batched_params(group["params"]): # Perform optimization step grad = p.grad if grad.is_sparse: raise RuntimeError( "PrAdam optimizer does not support sparse gradients" ) # State initialization if len(state) == 0: self._init_state(group, p, state) self._step_one_batch(group, p, state) return loss def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together parameters of a given shape. Args: group: Dict to look up configuration values. p: The parameter that we are initializing the state for, actually will be stacked-together "real" parameters state: Dict from string to whatever state we are initializing """ eps = group["eps"] size_update_period = group["size_update_period"] max_block_size = group["max_block_size"] state["step"] = 0 kwargs = {'device':p.device, 'dtype':p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than # compute "exp_avg" like in Adam, we store and decay a # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. state["delta"] = torch.zeros_like( p, memory_format=torch.preserve_format ) batch_size = p.shape[0] numel = p.numel() // batch_size if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs) # exp_avg_sq is in the projected space. It is periodically zeroed, and we # set zero_step whenever this happens. 
state["exp_avg_sq"] = torch.zeros_like( p, memory_format=torch.preserve_format ) # ignore_rank1_dims = True means we don't do our basis-changing update # on tensors like biases that only have one non-trivial dimension. # "rank1" refers to the rank of our estimate of the parameter # covariance, if we were to estimate it just from the current parameter # value. ignore_rank1_dims = True trivial_update = True for dim in range(1, p.ndim): size = p.shape[dim] if size == 1 or (size == numel and ignore_rank1_dims): continue trivial_update = False state["trivial_update"] = trivial_update # so we know whether to zero exp_avg_sq stats. if trivial_update: return # "zero_step" being a member of state is the sign that this parameter has # at least one dim that has a projection. It also records the most recent # step on which we zeroed state["exp_avg_sq"] state["zero_step"] = 0 # last_param_scale_update records the last time we updated the part of the learning rate # matrices that relates to the parameter covariance; we avoid doing this too often # as it doesn't change very rapidly and the SVD is quite slow. state["last_param_scale_update"] = -1 for dim in range(1, p.ndim): size = p.shape[dim] if size == 1 or (size == numel and ignore_rank1_dims): continue # Most of the time, (num_blocks, block_size) will equal (1, size) num_blocks, block_size = self._get_block_size(size, max_block_size) # Q_{dim} is be the learning-rate matrix for this # dimension, a matrix of indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate]. # this is used twice in the update step, once transposed and once without transpose. Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand( batch_size, num_blocks, block_size, block_size).contiguous() state[f"Q_{dim}"] = Q # cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating # all other dims as a batch axis. Also initialize as identity. state[f"param_cov_{dim}"] = Q.clone() # grad_cov_{dim} is the covariance of gradients on this axis (without # any co-ordinate changes), treating all other axes as as a batch axis. # This is needed when we re-diagonalize, and to compute the # grad_rms_{dim}. We store it as a decaying average, decaying with beta2 # only for purposes of computing the scalar factor on the # learning rate; and, as a result of this, also contributes something # to the gradient w.r.t. f"Q_{dim}", which is one reason # why we allocate a variable to keep track of its moving average # instead of just using a temporary and smoothing the scalar factor. state[f"grad_cov_{dim}"] = torch.zeros_like(Q) def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]: """ Returns information about the block size for a block-diagonal structure of covariances and co-ordinate transformations. size: The size of the parameter tensor on one of its dimensions, e.g. a channel dim or kernel size max_block_size: The maximum block size allowed (of the blocks in a block-diagonal structure) Returns: (num_blocks, block_size) where nun_blocks * block_size == size and block_size <= max_block_size """ for num_blocks in range(1, size): block_size = size // num_blocks if block_size * num_blocks == size and block_size <= max_block_size: return (num_blocks, block_size) assert False, (size, max_block_size) # code error or e.g. 
negative or # zero inputs def _step_one_batch(self, group: dict, p: Tensor, state: dict): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. Args: group: dict to look up configuration values p: parameter to update (actually multiple parameters stacked together as a batch) state: state-dict for p, to look up the optimizer state """ lr = group["lr"] size_update_period = group["size_update_period"] grad_cov_period = group["grad_cov_period"] eps = group["eps"] beta1 = group["betas"][0] grad = p.grad step = state["step"] delta = state["delta"] delta.mul_(beta1) batch_size = p.shape[0] numel = p.numel() // batch_size if numel > 1: # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = _sum(p * grad, exclude_dims=[0], keepdim=True) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) param_rms.copy_(_mean(p ** 2, exclude_dims=[0], keepdim=True).sqrt().clamp_(min=eps)) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. self._step_scalar(group, p, state) else: if self._is_lr_update_step(group, state): self._update_param_cov(group, p, state) self._compute_bases(group, p.shape, state) self._zero_exp_avg_sq(state) if step % grad_cov_period == 0: self._update_grad_cov(group, p, state) self._step(group, p, state) state["step"] = step + 1 def _size_update(self, group: dict, scale_grads: Tensor, p: Tensor, state: dict) -> None: """ Called only where p.numel() > 1, this updates the scale of the parameter. If we imagine: p = underlying_param * scale.exp(), and we are doing gradient descent on underlying param and on scale, this function does the update on `scale`. Args: group: dict to look up configuration values scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing grads w.r.t. the scales. p: The parameter to update state: The state-dict of p """ param_rms = state["param_rms"] beta1, beta2 = group["betas"] size_lr = group["lr"] * group["size_lr_scale"] param_min_rms = group["param_min_rms"] param_max_rms = group["param_max_rms"] step = state["step"] batch_size = p.shape[0] size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. beta2_corr = beta2 ** size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period bias_correction2 = 1 - beta2_corr ** size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. eps = 1.0e-10 # this value is not critical. denom = scale_exp_avg_sq.sqrt() + eps # shape: (batch_size, 1, 1, ..) 
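        # scale_step below is an Adam-style step on the log-scale of the parameter:
        # the numerator sums the scale gradients accumulated over the last
        # size_update_period steps, and denom is the usual second-moment denominator.
        # The masked_fill_() calls that follow override this step with a fixed-size
        # push back toward the allowed range whenever param_rms has drifted outside
        # [param_min_rms, param_max_rms].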
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom is_too_small = (param_rms < param_min_rms) is_too_large = (param_rms > param_max_rms) scale_step.masked_fill_(is_too_small, size_lr * size_update_period) scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. delta.add_(p * scale_step, alpha=(1-beta1)) def _update_param_cov(self, group: dict, p: Tensor, state: dict) -> None: """ Updates state[f"param_cov_{dim}"] for each dim of tensor p (except batch and trivial and rank-1 dims) """ eps = group["eps"] # zero_step is always the last time we called _update_param_cov. # Our aim is to compute the parameter covariance averaged over all time # (in steps) from the start, so "this_weight" that we give to this # new covariance is proportional to the interval of time it represents, # i.e. the interval from zero_step until step, while the existing "param_cov" # represents the interval from step until zero_step. # The min of 0.5 is a special case to ensure that the "torch.eye" we initialized # param_cov with gets some weight on the 1st time we call this. zero_step = state["zero_step"] step = state["step"] this_weight = min(0.5, (step - zero_step) / step) batch_size = p.shape[0] numel = p.numel() // batch_size for dim in range(1, p.ndim): size = p.shape[dim] try: param_cov = state[f"param_cov_{dim}"] except KeyError: continue # e.g. size == 1 or size == numel # param_cov shape: (batch_size, num_blocks, block_size, block_size) (batch_size, num_blocks, block_size, block_size) = param_cov.shape # M: (batch_size, num_blocks, x, block_size) M = p.transpose(dim, -1).reshape(batch_size, -1, num_blocks, block_size).transpose(1, 2) # will be batched matrix multiply. shape of this_param_cov: # (batch_size, num_blocks, block_size, block_size) this_param_cov = torch.matmul(M.transpose(2, 3), M) # normalize scale of this param_cov, in case parameter scale # changes significantly during training, which would cause some # parts of the training timeline to be more highly weighted. # shape of expression on r.h.s of "/=" has shape (batch_size, 1, 1, 1) this_param_cov /= _mean(_diag(this_param_cov), exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps param_cov.mul_(1-this_weight).add_(this_param_cov, alpha=this_weight) def _zero_exp_avg_sq(self, state: dict) -> None: """ Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"] """ state["exp_avg_sq"].zero_() state["zero_step"] = state["step"] def _is_lr_update_step(self, group: dict, state: dict) -> bool: """ Returns True if on this step we need to update the learning-rate matrices for this tensor and False if not (note: this may just mean we update the gradient-dependent part). The periodicity with which we update them increases from (by default) 200 at the start of training to 2000 later on. 
""" try: zero_step = state["zero_step"] except: # This parameter tensor has no learning-rate matrices to estimate return False step = state["step"] period_initial, period_final = group["lr_update_period"] # this formula gradually increases the periodicity from period_initial at the start # to period_final when step >> 4 * period_final cur_update_period = (period_initial + ((period_final - period_initial) * step / (step + 4 * period_final))) return (step >= zero_step + cur_update_period) def _compute_bases(self, group: dict, p_shape: torch.Size, state: dict): """ For each tensor dimension that we are preconditioning, this function sets state[f"Q_{dim}"] to an orthogonal matrix (per batch element and block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate]. This will be the matrix that diagonalizes Z, where Z is the symmetric matrix satisfying: Z G Z = P (eqn:1) with G being the gradient covariance grad_cov, P being the smoothed version of the parameter covariance param_cov, and Z the symmetric positive definit learning-rate matrices. We'll discuss later how we solve for Z. Firstly, because we want a basis that is as meaningful possible for smoothing P, we will first put P and G in a basis where G is diagonal. That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the left and right by U^T and U respetively, and inserting some expressions equivalent to I, we have: (U^T Z U) (U^T G U) (U^T Z U) = (U^T P U) or: Z' G' Z' = P' (eqn:2) where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal. We are actually going to be solving (eqn:2), and then computing: Z = U Z' U^T. A solution to (eqn:2) is as follows. We are going to be using a Cholesky-based solution in favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first have to be careful that the input is not close to singular, though). So, with the Cholesky decomposition of P' being: P' = C C^T, the solution is going to be: Z' = C (C^T G' C)^{-0.5} C^T (eqn:3) [[ we can verify that this is a solution by multiplying (eqn:1) by C^{-1} and C^{-T} on the left and right respectively, giving: (C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I which we can immediately see is the case. ]] It's actually going to be more convenient to compute Z'^{1} (the inverse of Z'), because ultimately only need the basis that diagonalizes Z, and (C^T G' C) may be singular. It's not hard to work out that: Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1} (eqn:4) Args: group: dict to look up config values p_shape: the shape of the batch of identical-sized tensors we are optimizing state: dict to look up optimization state Return: This function returns a list of Tensors P_proj indexed by dim from 0..ndim-1, the tensors being of shape (batch_size, size), which contain the diagonal of the smoothed parameter covariance P after projection by the orthogonal matrix state[f"Q_{dim}"] that diagonalizes Z (this is the space in which we do the actual update). """ ndim = len(p_shape) P_proj = [None] * ndim denom_rel_eps = group["denom_rel_eps"] eps = group["eps"] for dim in range(1, ndim): size = p_shape[dim] try: Q = state[f"Q_{dim}"] except KeyError: assert size == 1 or size == numel, size continue # e.g. 
size == 1 or size == numel param_cov = state[f"param_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size) grad_cov = state[f"grad_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size) (batch_size, num_blocks, block_size, block_size) = param_cov.shape # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape # is the same as param_cov and grad_cov. # # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size). U_g, G_prime, _ = _svd(grad_cov) # Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on # the singular values in places where it should be exactly zero. This is to keep # dimensions with zero grad separate from those with nonzero grad. G_prime = _diag(torch.matmul(U_g.transpose(2,3), torch.matmul(grad_cov, U_g))) G_prime_noeps = G_prime.clone() # Use the form of the diagonalized gradient matrix that we get after # we add the Adam-type smoothing with epsilon. G_prime += (_mean(G_prime, exclude_dims=[0], keepdim=True) *(denom_rel_eps * denom_rel_eps) + (eps * eps)) # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime. # It is not smoothed yet. P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g)) P_prime_unsmoothed = P_prime P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime) # C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3)) # C is of shape (batch_size, num_blocks, block_size, block_size). #def _fake_cholesky(X): # U_, S_, _ = _svd(X) # return U_ * S_.sqrt().unsqueeze(-2) #C = _fake_cholesky(P_prime) C = P_prime.cholesky() # OK, P == U C C^T U^T. # A matrix that takes normally distributed data to P would # be U_g C, because C I C^T = C C^T = P. We can actually use *any* matrix # that takes normally distributed data to P, so we can use # U_g C U for any orthogonal U, since U_g C U I U^T C^T U_g^T == P. # So there is no harm in choosing a matrix U that diagonalizes the # projected grad_cov. grad_cov gets projected by # this projects P; its transpose projects the gradient. UC = torch.matmul(U_g, C) # instead of projecting grad_cov, we can just use its diagonal, forget the # # U_g part of the transform, and project with C. grad_cov_proj = torch.matmul(C.transpose(2, 3) * G_prime_noeps.unsqueeze(-1), C) # OK, grad_cov is diagonalized by U^T C^T U_g^T. So the projection that we # apply to the param cov is U_g C U U, S, _ = _svd(grad_cov_proj) # proj is indexed [batch_idx,block_idx,canonical_coordinate,diagonalized_coordinate], # so we need to transpose to get Q_{dim}. proj = torch.matmul(UC, U) Q[:] = proj.transpose(2, 3) continue def _smooth_param_cov(self, group: dict, p_shape: torch.Size, P_prime: Tensor, G_prime: Tensor) -> Tensor: """ This function returns a modified/smoothed version of the parameter covariance P_prime. Args: group: dict to look up config values p_shape: The shape of the parameter we are optimizing P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size), containing the parameter covariance in a basis that diagonalizes the gradient covariance. G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks, block_size) state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter p, averaged over time, and taken over dimension `dim` of the tensor. The smoothing done here limits the extent to which the parameter covariance can be strongly "off-diagonal" with respect to the gradient covariance. 
That is: if the parameter covariance is just the gradient covariance to some power, this function does no smoothing; but if it is highly off-diagonal we do more smoothing. """ # do smoothing based on 'rank', # that is intended to compensate for bad estimates of P. (batch_size, num_blocks, block_size, block_size) = P_prime.shape # `rank_per_block` is the rank of each block of P_prime if we were to estimate it from just one # parameter tensor. We average it over time, but actually it won't be changing # too much, so `rank` does tell us something. size = num_blocks * block_size rank = p_shape.numel() // (size * batch_size) # actually the rank of each block smooth0 = group["param_rms_smooth0"] smooth1 = group["param_rms_smooth1"] # We want expr for smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size) # for "size" here, we actually want to use block_size, since we are concerned about the # robustness of the covariance within these blocks. # param_rms_smooth{0,1} represents the user-specified desired amount of smoothing # when rank==0*size and rank==1*size, respectively. # from rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0. # from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta), # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1 smooth = smooth0 * block_size / ((smooth0/smooth1 - 1) * rank + block_size) if True: logging.info(f"block size={block_size}, rank={rank}, smooth={smooth}") # add rank-dependent smoothing amount to diagonal of P_prime. _diag() returns an aliased tensor. # we don't need to multiply `smooth` by anything, because at this point, P_prime should have # diagonal elements close to 1. P_prime_diag = _diag(P_prime) P_prime_diag_mean = _mean(P_prime_diag, exclude_dims=[0], keepdim=True) P_prime_diag += smooth * P_prime_diag_mean if True: # This block smooths G_prime. # Make sure G_prime has unit mean and no eigenvalue is super small. Note, G_prime # is already diagonalized, the variable G_prime is just the tensor of eigenvalues. G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True) G_prime_min = group["cov_min"][3] # make sure G_prime has no zero eigs, and is unit mean. G_prime = G_prime + G_prime_min * G_prime_mean + 1.0e-20 G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True) # it now has unit mean.. G_prime_max = group["cov_max"][3] G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max G_prime_pow = group["cov_pow"][3] if G_prime_pow != 1.0: G_prime = G_prime ** G_prime_pow G_prime_rms = G_prime.sqrt() G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2) # P_gnorm is a version of P_prime that is multiplied by G (actually # G^{0.5} P_prime G^{0.5}), so that it reflects the amount of # loss-function change in each dimension. We also tried smoothing # a version of P_prime divided by G, but it seemed not to be helpful. P_gnorm = P_prime * G_prime_scale # Apply another round of smoothing "relative to G" P_gnorm = self._smooth_cov(P_gnorm, group["cov_min"][1], group["cov_max"][1], group["cov_pow"][1]) # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime. P_prime = P_gnorm / G_prime_scale # Apply a 3rd round of smoothing in the canonical basis. 
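        # (This third stage uses cov_min[2], cov_max[2] and cov_pow[2]; see the
        # class-level docstring of cov_min/cov_max for what stages (1), (2), (3) mean.)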
P_prime = self._smooth_cov(P_prime, group["cov_min"][2], group["cov_max"][2], group["cov_pow"][2]) return P_prime def _smooth_cov(self, X: Tensor, min_eig: float, max_eig: float, power: float = 1.0) -> Tensor: """ Returns a `smoothed` version of a symmetric positive definite covariance matrix [with block-diagonal structure, in a batch]. This is done without SVD (which can be very slow). The eigenvalues L will be transformed as: L = L + min_eig * L.mean() + eps L /= L.mean() L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig L = L ** power # need SVD for this, will get rid of the requirement later. L /= L.mean() # Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type: # plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10 # [this starts to diverge after 5 or so] Args: min_eig: minimum allowed eigenvalue of returned X max_eig: maximum allowed eigenvalue of returned X power: power to take eigenvalues to X: the batch of symmetric positive definite tensors we are smoothing; of shape (batch_size, num_blocks, block_size, block_size) """ eps = 1.0e-10 if power != 1.0: U, S, _ = _svd(X) S_mean = _mean(S, exclude_dims=[0], keepdim=True) S = S + min_eig * S_mean + eps S_mean = S_mean * (1 + min_eig) + eps S = S / S_mean S = 1. / (1./S + 1./max_eig) S = S ** power S = S / _mean(S, exclude_dims=[0], keepdim=True) return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3)) else: X = X.clone() diag = _diag(X) # Aliased with X mean_eig = _mean(diag, exclude_dims=[0], keepdim=True) diag += (mean_eig * min_eig + eps) cur_diag_mean = mean_eig * (1 + min_eig) + eps X /= cur_diag_mean.unsqueeze(-1) # OK, now the mean of the diagonal of X is 1 (or less than 1, in # which case X is extremely tiny). # eig_ceil is the maximum possible eigenvalue that X could possibly # have at this time. eig_ceil = X.shape[-1] while max_eig <= eig_ceil * 0.5: # this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically # increasing from 0..eig_ceil. logging.info(f"X={X}, eig_ceil={eig_ceil}") X = X - 0.5/eig_ceil * torch.matmul(X, X) eig_ceil = 0.5 * eig_ceil # max_eig > eig_ceil * 0.5 if max_eig < eig_ceil: # map l -> l - coeff*l*l, if l==eig_ceil, this # takes us to: # eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil # == max_eig # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5] # means that the function is monotonic on inputs from 0 to eig_ceil. coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil) X = X - coeff * torch.matmul(X, X) # Normalize again. X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1) return X def _update_grad_cov(self, group: dict, p: Tensor, state: dict): """ Updates the gradient covariance state[f"grad_cov_{dim}"] for appropriate dimensions of the tensor. """ beta2 = group["betas"][1] grad = p.grad batch_size = p.shape[0] for dim in range(1, grad.ndim): # dim 0 is batch dim. size = grad.shape[dim] name = f"grad_cov_{dim}" if name not in state: continue # e.g. 
size==1, size == numel, size > max_block_size grad_cov = state[name] # shaped: (batch_size, num_blocks, block_size, block_size) (batch_size, num_blocks, block_size, block_size) = grad_cov.shape g = grad.transpose(-1, dim) # if grad were of shape (batch_size, x, size, y, z), # g would be of shape (batch_size, x, z, y, size); and # after the next line, g will be of shape # (batch_size, x, z, y, num_blocks, block_size) g = g.reshape(*g.shape[:-1], num_blocks, block_size) g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size) g = g.reshape(batch_size, num_blocks, -1, block_size) # now g is of shape (batch_size, num_blocks, x*z*y, block_size) # this_grad_cov: (batch_size, num_blocks, block_size, block_size) this_grad_cov = torch.matmul(g.transpose(-2, -1), g) grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2)) def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the tensor has more than 1 elements. It multiplies the moving-average gradient tensor by a state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization, and takes a step in that direction. Args: group: A dict which will be used to look up configuration values p: The parameter to be updated grad: The grad of p state: The state-dict corresponding to parameter p batch: if true, the first dimension of p is a batch dim, i.e. a batch of parameter matrices. This function modifies p. """ grad = p.grad lr = group["lr"] beta1, beta2 = group["betas"] eps = group["eps"] denom_rel_eps = group["denom_rel_eps"] step = state["step"] grad = self._project(grad, state, forward=True) exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1-beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) if bias_correction2 < 0.99: # note: not in-place. exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) denom = exp_avg_sq.sqrt() denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True) grad = grad / denom # and project back.. grad = self._project(grad, state, forward=False) alpha = -lr * (1-beta1) * state["param_rms"] delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. p will not be a scalar, because it has a batch dim. """ beta1, beta2 = group["betas"] scalar_max = group["scalar_max"] eps = group["eps"] lr = group["lr"] grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. bias_correction2 = 1 - beta2 ** (state["step"] + 1) denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] delta.add_(grad / denom, alpha=-lr*(1-beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) def _project(self, M: Tensor, state: dict, forward: bool) -> Tensor: """ Multiply a tensor M by a projection Q on each of its dimensions (for which we have such a projection) Args: M: The tensor to project, e.g. a gradient or a parameter change. state: dict to look up the projections forward: if True, go in the forward direction (from canonical to diagnonalized co-ordinates); if False, the reverse. 
        They differ by a transpose, not an inverse, and do not make a round trip.
        """
        for dim in range(1, M.ndim):  # dim 0 is batch dim (batch of parameter
                                      # tensors of the same shape)
            try:
                Q = state[f"Q_{dim}"]  # shape: (batch_size, num_blocks, block_size, block_size)
            except KeyError:
                continue  # no projection for this dim
            (batch_size, num_blocks, block_size, block_size) = Q.shape
            size = M.shape[dim]  # == num_blocks * block_size
            if forward:
                # Q is indexed
                # [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
                # the forward direction we want to change canonical to
                # diagonalized index, so we transpose.
                Q = Q.transpose(2, 3)
            # assume M currently has shape (batch_size, x, size, y, z), and dim==2.
            M = M.transpose(-1, dim)  # (batch_size, x, y, z, size)
            M = M.reshape(batch_size, *M.shape[1:-1], num_blocks,
                          block_size)  # (batch_size, x, y, z, num_blocks, block_size)
            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)

            while Q.ndim < M.ndim:
                Q = Q.unsqueeze(2)
            # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
            if M.ndim < Q.ndim:
                # special case where M shape is e.g. (batch_size, num_blocks, x)
                M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
            else:
                M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)

            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
            M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
        return M


def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
    """
    Moves the position of one dimension in a Tensor, while keeping
    the remaining dims in the same order.  E.g. if x has shape [10, 1, 2, 3],
    _move_dim(x, 3, 1) will have shape [10, 3, 1, 2].

    Args:
       x: Tensor to reshape
       orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
       new_dim:  New position of the original dimension, with -x.ndim <= new_dim < x.ndim
    Returns: a permuted view of x
    """
    if orig_dim < 0:
        orig_dim += x.ndim
    if new_dim < 0:
        new_dim += x.ndim
    dims = list(range(x.ndim))
    dims[orig_dim] = -1
    dims.remove(-1)
    dims.insert(new_dim, orig_dim)
    return x.permute(dims)


def _diag(x: Tensor):
    """
    Like torch.diag(), this returns the diagonal of a matrix; but it returns an
    aliased tensor, and it supports batch dims, i.e. input of shape (B, M, M) returns
    output of shape (B, M), and input of shape (A, B, M, M) returns output of
    shape (A, B, M).
    """
    stride = x.stride()
    if x.ndim == 3:
        (B, M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
    elif x.ndim == 4:
        (B, C, M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(B, C, M),
                           stride=(stride[0], stride[1], stride[2] + stride[3]))
    elif x.ndim == 2:
        (M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
    else:
        assert False, f"_diag() requires 2 <= x.ndim <= 4, got {x.ndim}"
    return ans


def _sum(x: Tensor, exclude_dims: List[int] = [], keepdim: bool = False):
    """
    Version of torch.sum where you specify dims to exclude, not dims to sum.
    """
    if len(exclude_dims) == 0:
        # dim=[] works the same as tuple(range(x.ndim)), which makes
        # no sense.. supplying the full tuple for clarity and in case
        # the interface changes in future.
        return x.sum(dim=tuple(range(x.ndim)), keepdim=keepdim)
    elif x.ndim == 1:
        assert exclude_dims == [0] or exclude_dims == [-1]
        # if one dim is excluded, there are no dims to sum, and
        # x.sum(dim=[]) sums all dims so we should not call sum().
        return x
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims] dims = [i for i in range(x.ndim) if i not in exclude_dims_norm] assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims return x.sum(dim=dims, keepdim=keepdim) def _mean(x: Tensor, exclude_dims: List[int] = [], keepdim: bool = False): """ Version of torch.mean where you specify dims to exclude, not dims to mean. """ if len(exclude_dims) == 0: # dim=[] works the same as tuple(range(x.ndim)), which makes # no sense.. supplying the full tuple for clarity and in case # the interface changes in future. return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim) elif x.ndim == 1: assert exclude_dims == [0] or exclude_dims == [-1] return x # if one dim is excluded, there are no dims to mean, and # x.mean(dim=[]) means all dims so we should not call mean(). exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims] dims = [i for i in range(x.ndim) if i not in exclude_dims_norm] assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims return x.mean(dim=dims, keepdim=keepdim) def _svd(x: Tensor, recursion_depth: int = 0): # Wrapper for torch svd that catches errors and re-tries (to address bugs in # some versions of torch SVD for CUDA) xsum = x.sum() if not (xsum - xsum == 0): raise RuntimeError("svd on inf or nan input") try: U, S, V = x.svd() s = U.sum() if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01): raise RuntimeError(f"svd failed (or random test), sum={s}") return U, S, V # success except: if recursion_depth < 2: logging.warning(f"svd failed: recursion_depth={recursion_depth}, " f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely " f"error in SVD. Will try again after random permutation of rows") # randomly permute the row indexes of the matrices, and retry, hoping this # fixes the issue n = x.shape[-2] randperm = torch.randperm(n, device=x.device) inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64) inv_randperm[randperm] = torch.arange(n, device=x.device) x = torch.index_select(x, dim=-2, index=randperm) U,S,V = _svd(x, recursion_depth + 1) return torch.index_select(U, dim=-2, index=inv_randperm), S, V elif recursion_depth < 4: logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, x.sum()={x.sum()}. Will try orthogonal transformation") Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype).svd() U, S, V = _svd(torch.matmul(Urand, x), recursion_depth + 1) return torch.matmul(Urand.t(), U), S, V else: raise RuntimeError(f"svd failed after {recursion_depth} tries") class Cain(Optimizer): """Implements Cain algorithm, which is a substantially modified form of the Adam optimizer. The modifications can be summarized as follows: (i) change Adam to what we called "Eve", which adds an extra configuration value, "target_rms" (default value: 0.1); and for non-scalar parameters, if their root-mean-square value exceeds "target_rms", we apply weight decay (aggressive enough weight decay that the rms stays at target_rms). This weight decay is applied in the same way as in AdamW, i.e. by parameter shrinkage rather than by modifying the gradients. (ii) By limiting the rms to 0.1, we reduce our modeling power; we restore it by replacing all weights and biases-- in fact, all model parameters-- with expressions like weight * weight_scale.exp() or bias * bias_scale.exp(). This has the benefit of making the learnable offsets and scales of BatchNorm/LayerNorm unnecessary. 
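
    (A useful identity for this reparameterization: if p = underlying_param * scale.exp(),
    then d(p)/d(scale) = p, so d(loss)/d(scale) = (p * grad).sum().  This is exactly the
    quantity computed as `scale_deriv` in step() below when the optimizer simulates
    learning the scale.)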
Below we describe the changes from Eve to Cain: (iii) We removed all the weight_scale and bias_scale parameters from the model and "moved them into the optimizer". Now we allow parameters to have any rms value (that's <= 10), but we: (a) make the update speed for a parameter proportional to its rms value. (b) Multiply the learning rate by 10 because we are getting rid of target_rms=0.1 (c) simulate the effect of learning the weight_scale or bias_scale in log space using standard Adam, with a scale to prevent it from learning too fast. (see scale_exp_avg_sq in the code below). The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. Eve is unpublished so far. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) scale_speed (float, optional): we multiply the learning rate by this value when learning (in log space) the scales of the parameter tensors rms_eps (float, optional): epsilon value such that when parameter rms value goes below this, we stop learning slower for that parameter, i.e. we treat the rms as if it were rms_eps. In practice, rms values will not go much below this. param_max (float, optional): maximum absolute value for any parameter; we will clamp parameters to satisfy -param_max <= param <= param_max .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ def __init__( self, params, lr=1e-2, betas=(0.9, 0.98), eps=1e-8, scale_speed=0.05, rms_eps=1.0e-05, param_max=20.0, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( "Invalid beta parameter at index 0: {}".format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( "Invalid beta parameter at index 1: {}".format(betas[1]) ) if not 0 < scale_speed < 1.0: raise ValueError("Invalid scale_speed value: {}".format(scale_speed)) if not 0.1 < param_max <= 10000.0: raise ValueError("Invalid param_max value: {}".format(param_max)) if not 0.0 < rms_eps < 0.1: raise ValueError("Invalid rms_eps value: {}".format(rms_eps)) defaults = dict( lr=lr, betas=betas, eps=eps, scale_speed=0.05, rms_eps=rms_eps, param_max=param_max, ) super(Cain, self).__init__(params, defaults) def __setstate__(self, state): super(Cain, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: beta1, beta2 = group["betas"] lr = group["lr"] scale_speed = group["scale_speed"] rms_eps = group["rms_eps"] eps = group["eps"] param_max = group["param_max"] for p in group["params"]: if p.grad is None: continue # Perform optimization step grad = p.grad if grad.is_sparse: raise RuntimeError( "Cain does not support sparse gradients" ) state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values state["delta"] = torch.zeros_like( p, memory_format=torch.preserve_format ) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like( p, memory_format=torch.preserve_format ) if p.numel() > 1: # we keep track of the RMS value of the parameters to # help set the learning rate... this emulates Eve. state["param_rms"] = torch.ones_like(p) # we learn the scale on the parameters in a way that # also emulates Eve. scale_exp_avg_sq is the # moving-average square of the gradient w.r.t. this # scalar scale. state["scale_exp_avg_sq"] = torch.zeros((), device=p.device, dtype=p.dtype) delta, exp_avg_sq = state["delta"], state["exp_avg_sq"] step = state["step"] delta.mul_(beta1) if p.numel() > 1: if True: # This block learns the scale on the parameters. # scale_deriv is the derivative w.r.t. a log-space scale # on the parameters, i.e. if we imagine: p = # underlying_param * scale.exp(), scale_deriv is the # derivative w.r.t. scale (it does not matter # what value `scale` currently has). scale_deriv = (p * grad).sum() scale_exp_avg_sq = state["scale_exp_avg_sq"] scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv, value=1 - beta2) # should actually be step + 1, so on 1st minibatch we are not learning # anything here. May fix this at some point. scale_bias_correction2 = 1 - beta2 ** (step + 1) scale_denom = (scale_exp_avg_sq.sqrt()).add_(eps) scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1) # scale_delta is going to be the change in log-scale (before momentum), i.e. if # p = underlying_param * scale.exp(), # delta is the change in `scale`. scale_delta = scale_alpha * (scale_deriv / scale_denom) delta.add_(p, alpha=scale_delta) # Decay the second moment running average coefficient exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) bias_correction2 = 1 - beta2 ** step grad_rms = (exp_avg_sq.sqrt()).add_(eps) this_delta = grad / grad_rms param_rms = state["param_rms"] self._update_param_rms(p, param_rms, step, rms_eps) alpha = (-(1-beta1) * lr * bias_correction2) # geometrically interpolate the per-element param_rms with its mean #param_rms_half = (param_rms * param_rms.mean()) ** 0.5 delta.add_(this_delta * param_rms, alpha=alpha) # below, will do: delta.add_(this_delta, alpha=alpha) else: # p.numel() == 1 # Decay the second moment running average coefficient exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) bias_correction2 = 1 - beta2 ** step denom = (exp_avg_sq.sqrt()).add_(eps) this_delta = grad / denom alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5) delta.add_(this_delta, alpha=alpha) p.add_(delta) if step % 10 == 0: p.clamp_(min=-param_max, max=param_max) state["step"] += 1 return loss def _update_param_rms(self, p: Tensor, param_rms: Tensor, step: int, rms_eps: float): """ Updates `param_rms`. 
Require p.numel() > 1 Args: p: parameter whose magnitude/rms we are measuring param_rms: a Tensor containing non-negative values, with the same shape as p. Contains an estimate of the root-mean-square value of the parameter. This has a special structure where it must be a product of one-dimensional quantities, one for each axis of p. We exclude certain axes to avoid estimating the rms value from a too-small number of parameters. step: the step of the algorithm rms_eps: epsilon such that we'll aim to prevent param_rms from going much lower than this. """ if step % 1000 == 0: # Periodically refresh param_rms to keep the invariance that it is a # product over its axes (numerical roundoff might destroy this) param_rms.fill_(1.0) for dim in range(p.ndim): self._update_param_rms_for_dim(p, param_rms, dim, rms_eps) else: dim = step % p.ndim self._update_param_rms_for_dim(p, param_rms, dim, rms_eps) def _update_param_rms_for_dim(self, p: Tensor, param_rms: Tensor, dim: int, rms_eps: float): """ Updates `param_rms` on one of its axes/dimensions. Args: p: parameter whose magnitude/rms we are measuring param_rms: a Tensor containing non-negative values, with the same shape as p. Contains an estimate of the root-mean-square value of the parameter. This has a special structure where it must be a product of one-dimensional quantities, one for each axis of p. We exclude certain axes to avoid estimating the rms value from a too-small number of parameters. dim: with 0 <= dim < p.ndim, the dimension/axis on which to update the scale factor rms_eps: epsilon such that we'll aim to prevent param_rms from going much lower than this. """ # var_factor is the square of the factor to change param_rms by. p will # not be super large, we limit it to -param_max <= p <= param_max. # Clamping p**2 to min=1.0e-10 should prevent param_rms from getting much # smaller than 1.0e-05, which should prevent division by zero here. var_factor = (p ** 2).clamp(min=rms_eps*rms_eps) / (param_rms ** 2) if p.numel() <= 2 * p.shape[dim]: var_factor = var_factor.mean() else: dims = tuple(d for d in range(p.ndim) if d != dim) var_factor = var_factor.mean(dim=dims, keepdim=True) #if random.random() < 0.01: # logging.info(f"shape={p.shape}, var_factor={var_factor}") param_rms.mul_(var_factor.sqrt()) class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the batch and the epoch. """ def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError( "{} is not an Optimizer".format(type(optimizer).__name__) ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) self.base_lrs = [ group["initial_lr"] for group in optimizer.param_groups ] self.epoch = 0 self.batch = 0 def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. """ return { "base_lrs": self.base_lrs, "epoch": self.epoch, "batch": self.batch, } def load_state_dict(self, state_dict): """Loads the schedulers state. Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ self.__dict__.update(state_dict) def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler. 
Will be a list of float.""" return self._last_lr def get_lr(self): # Compute list of learning rates from self.epoch and self.batch and # self.base_lrs; this must be overloaded by the user. # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] raise NotImplementedError def step_batch(self, batch: Optional[int] = None) -> None: # Step the batch index, or just set it. If `batch` is specified, it # must be the batch index from the start of training, i.e. summed over # all epochs. # You can call this in any order; if you don't provide 'batch', it should # of course be called once per batch. if batch is not None: self.batch = batch else: self.batch = self.batch + 1 self._set_lrs() def step_epoch(self, epoch: Optional[int] = None): # Step the epoch index, or just set it. If you provide the 'epoch' arg, # you should call this at the start of the epoch; if you don't provide the 'epoch' # arg, you should call it at the end of the epoch. if epoch is not None: self.epoch = epoch else: self.epoch = self.epoch + 1 self._set_lrs() def _set_lrs(self): values = self.get_lr() assert len(values) == len(self.optimizer.param_groups) for i, data in enumerate(zip(self.optimizer.param_groups, values)): param_group, lr = data param_group["lr"] = lr self.print_lr(self.verbose, i, lr) self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: logging.info( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) class Eve(Optimizer): """ Implements Eve algorithm. This is a modified version of AdamW with a special way of setting the weight-decay / shrinkage-factor, which is designed to make the rms of the parameters approach a particular target_rms (default: 0.1). This is for use with networks with 'scaled' versions of modules (see scaling.py), which will be close to invariant to the absolute scale on the parameter matrix. The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. Eve is unpublished so far. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay coefficient (default: 3e-4; this value means that the weight would decay significantly after about 3k minibatches. Is not multiplied by learning rate, but is conditional on RMS-value of parameter being > target_rms. target_rms (float, optional): target root-mean-square value of parameters, if they fall below this we will stop applying weight decay. .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ def __init__( self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, weight_decay=1e-3, target_rms=0.1, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( "Invalid beta parameter at index 0: {}".format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( "Invalid beta parameter at index 1: {}".format(betas[1]) ) if not 0 <= weight_decay <= 0.1: raise ValueError( "Invalid weight_decay value: {}".format(weight_decay) ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, target_rms=target_rms, ) super(Eve, self).__init__(params, defaults) def __setstate__(self, state): super(Eve, self).__setstate__(state) @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue # Perform optimization step grad = p.grad if grad.is_sparse: raise RuntimeError( "AdamW does not support sparse gradients" ) state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like( p, memory_format=torch.preserve_format ) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like( p, memory_format=torch.preserve_format ) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] state["step"] += 1 bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) step_size = group["lr"] / bias_correction1 target_rms = group["target_rms"] weight_decay = group["weight_decay"] if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). is_above_target_rms = p.norm() > ( target_rms * (p.numel() ** 0.5) ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: step = (exp_avg/denom) * step_size logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") return loss class Eden(LRScheduler): """ Eden scheduler. lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) E.g. suggest initial-lr = 0.003 (passed to optimizer). Args: optimizer: the optimizer to change the learning rates on lr_batches: the number of batches after which we start significantly decreasing the learning rate, suggest 5000. lr_epochs: the number of epochs after which we start significantly decreasing the learning rate, suggest 6 if you plan to do e.g. 20 to 40 epochs, but may need smaller number if dataset is huge and you will do few epochs. 
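
    Example (an illustrative sketch based on _test_eden() below; the choice of model,
    optimizer and constants is arbitrary):

        m = torch.nn.Linear(100, 100)
        optim = Cain(m.parameters(), lr=0.003)
        scheduler = Eden(optim, lr_batches=5000, lr_epochs=6)
        for epoch in range(10):
            scheduler.step_epoch(epoch)
            for step in range(20):
                loss = (m(torch.randn(30, 100)) ** 2).mean()
                loss.backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()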
""" def __init__( self, optimizer: Optimizer, lr_batches: Union[int, float], lr_epochs: Union[int, float], verbose: bool = False, ): super(Eden, self).__init__(optimizer, verbose) self.lr_batches = lr_batches self.lr_epochs = lr_epochs def get_lr(self): factor = ( (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) ** -0.25 ) return [x * factor for x in self.base_lrs] def _test_eden(): m = torch.nn.Linear(100, 100) optim = Cain(m.parameters(), lr=0.003) scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) for epoch in range(10): scheduler.step_epoch(epoch) # sets epoch to `epoch` for step in range(20): x = torch.randn(200, 100).detach() x.requires_grad = True y = m(x) dy = torch.randn(200, 100).detach() f = (y * dy).sum() f.backward() optim.step() scheduler.step_batch() optim.zero_grad() logging.info(f"last lr = {scheduler.get_last_lr()}") logging.info(f"state dict = {scheduler.state_dict()}") def _test_eve_cain(): import timeit from scaling import ScaledLinear E = 100 B = 4 T = 2 logging.info("in test_eve_cain") #device = torch.device('cuda') device = torch.device('cpu') dtype = torch.float32 fix_random_seed(42) # these input_magnitudes and output_magnitudes are to test that # Abel is working as we expect and is able to adjust scales of # different dims differently. input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() #for iter in [3, 2, 1, 0]: # will restore 1,0 later for iter in [3, 2]: fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear hidden_dim = 200 m = torch.nn.Sequential(Linear(E, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, hidden_dim), torch.nn.PReLU(), Linear(hidden_dim, E), ).to(device) train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: optim = Cain(m.parameters(), lr=0.03) elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03) elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. 


def _test_eden():
    m = torch.nn.Linear(100, 100)
    optim = Cain(m.parameters(), lr=0.003)

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")


def _test_eve_cain():
    import timeit
    from scaling import ScaledLinear

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device('cpu')
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # the optimizer is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    # for iter in [3, 2, 1, 0]:  # will restore 1,0 later
    for iter in [3, 2]:
        fix_random_seed(42)
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear
        hidden_dim = 200
        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype)
                * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            optim = Cain(m.parameters(), lr=0.03)
        elif iter == 2:
            optim = PrAdam(m.parameters(), lr=0.03)
        elif iter == 3:
            optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2, 3]:
            #     optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     )  # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}, scales={scale1,scale1b,scale2,scale2b}
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

        # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")


def _test_svd():
    device = 'cuda'
    for i in range(1000):
        X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(
            100, 100, device=device
        )
        U, S, V = _svd(X)
        X2 = torch.matmul(U * S, V.t())
        assert torch.allclose(X2, X, atol=0.001)


def _test_smooth_cov():
    b = (torch.arange(10) * 0.5).exp().diag()
    b = b.unsqueeze(0).unsqueeze(0)
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
    logging.info(f"c[noinv] = {_diag(c)}")
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
    logging.info(f"c[svd] = {_diag(c)}")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output("git status -uno .; git log -1", shell=True)
    _test_smooth_cov()
    logging.info(s)
    # _test_svd()
    _test_eve_cain()
    # _test_eden()