1963 lines
83 KiB
Python

# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import List, Optional, Union, Tuple, List
from lhotse.utils import fix_random_seed
import torch
from scaling import ActivationBalancer
import random
from torch import Tensor
from torch.optim import Optimizer
import logging
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension.
Args:
params:
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
def batched_params(self, param_group):
"""
This function returns an iterator over tuples (p, state), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this fake parameter
(it will be physically located in the "state" for one of the real
parameters, the last one that has any particular shape and dtype).
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
for p, state in self.batched_params(group["params"]):
...
</code>
Args:
group: a parameter group, which is a list of parameters; should be
one of self.groups.
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
for p in param_group:
assert p.grad is not None and "All parameters must have a grad when you call optimizer.step()"
assert not p.grad.is_sparse and "Sparse gradients not supported."
key = (str(p.dtype), *p.shape)
batches[key].append(p)
for p in param_group:
key = (str(p.dtype), *p.shape)
batch = batches[key]
if p is batch[0]: # if this is the 1st param in the batch...
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([p.grad for p in batch])
p_stacked.grad = grad
yield p_stacked, state # <-- calling code will do the actual optimization here!
# Now un-stack the parameter changes
for i,p in enumerate(batch):
p.copy_(p_stacked[i])
class PrAdam(BatchedOptimizer):
"""
Implements 'Projected Adam', a variant of Adam with projections for each
dim of each tensor.
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
scalar tensors. Must satisfy 0 < beta <= beta2 < 1.
size_lr_scale: A scaling factor on the learning rate, that we use to update the
scale of each parameter tensor. If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
is the scaling factor on the learning rate of p_scale.
param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
assumed rank of param covariance [==product of sizes on the other
tensor dims] approaches 0.
param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param covariance equals the dimension of the covaraince matrix.
param_rms_smooth{0,1} determine the smoothing proportions for other
conditions.
cov_min,cov_max: [IMPORTANT] 5-tuples of minimums and maximums of the diagonal values of
covariance matrices, after normalizing to unit-mean. The first 3 are
for smoothing the parameter covariance, normalized in 3 different ways:
(1) relative to its own diagonal (in a basis that diagonalizes the grad
covariance);
(2) multiplied by the grad covariance,
(3) in the canonical basis.
(4) is for smoothing the grad covariance used for (2)
cov_pow: This was mainly added for development and experimentation purposes;
it allows you to smooth the parameter covariance matrices at the
stages (1), (2), (3) of smoothing mentioned above, and also
the gradient covariance matrix used in stage (2) of smoothing. If the
1st 3 values are not 1.0 it will cause a measurable speed penalty because it
requires SVD. Recommend to leave all these at 1.0.
eps: A general-purpose epsilon to prevent division by zero
denom_rel_eps: An epsilon value used to keep the elements of denominator (exp_avg_grad_sqrt.sqrt())
not too small relative to the mean value; this will limit the speed of update
along directions with very small gradients.
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be <= this value)
scalar_max: Maximum absolute value for scalar parameters (applicable if your
model has any parameters with numel() == 1)
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time
in the update.
lr_update_period: [IMPORTANT]: A 2-tuple of ints that Determines the periodicity, in steps, with
which we update the learning-rate matrices. The first number is the periodicity at
the start of training, the second number is the periodicity
later in training, and we gradually increase from one to the other.
The reason for such a complicated schedule is that updating the learning
rate matrices is very slow, principally because it requires SVD, and SVD
seems to have quite slow implementations.
max_block_size: [IMPORTANT] The maximum block size in block-diagonal co-ordinate
transformations. You can probably set this to 512 or 1024. Larger
values will require more MEMORY and may be a bit slower, but should
lead to better optimization performance.
"""
def __init__(
self,
params,
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
cov_min=(0.025, 0.0025, 0.02, 0.0001),
cov_max=(10.0, 80.0, 5.0, 400.0),
cov_pow=(1.0, 1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
eps=1.0e-08,
denom_rel_eps=1.0e-05,
param_min_rms=1.0e-05,
param_max_rms=2.0,
scalar_max=2.0,
size_update_period=4,
lr_update_period=(200, 1000),
# grad_cov_period does not affect speed very much. If __name__ ==
# "__main__", keeping this coprime to 20 which is the number of
# sample pairs used in the test code
grad_cov_period=(3 if __name__ == "__main__" else 5),
max_block_size=1024,
):
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
cov_min=cov_min,
cov_max=cov_max,
cov_pow=cov_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
betas=betas,
eps=eps,
denom_rel_eps=denom_rel_eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
lr_update_period=lr_update_period,
grad_cov_period=grad_cov_period,
max_block_size=max_block_size,
)
super(PrAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PrAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
batch = True
for group in self.param_groups:
for p, state in self.batched_params(group["params"]):
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"PrAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
self._step_one_batch(group, p, state)
return loss
def _init_state(self,
group: dict,
p: Tensor,
state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for, actually
will be stacked-together "real" parameters
state: Dict from string to whatever state we are initializing
"""
eps = group["eps"]
size_update_period = group["size_update_period"]
max_block_size = group["max_block_size"]
state["step"] = 0
kwargs = {'device':p.device, 'dtype':p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
# set zero_step whenever this happens.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# ignore_rank1_dims = True means we don't do our basis-changing update
# on tensors like biases that only have one non-trivial dimension.
# "rank1" refers to the rank of our estimate of the parameter
# covariance, if we were to estimate it just from the current parameter
# value.
ignore_rank1_dims = True
trivial_update = True
for dim in range(1, p.ndim):
size = p.shape[dim]
if size == 1 or (size == numel and ignore_rank1_dims):
continue
trivial_update = False
state["trivial_update"] = trivial_update # so we know whether to zero exp_avg_sq stats.
if trivial_update:
return
# "zero_step" being a member of state is the sign that this parameter has
# at least one dim that has a projection. It also records the most recent
# step on which we zeroed state["exp_avg_sq"]
state["zero_step"] = 0
# last_param_scale_update records the last time we updated the part of the learning rate
# matrices that relates to the parameter covariance; we avoid doing this too often
# as it doesn't change very rapidly and the SVD is quite slow.
state["last_param_scale_update"] = -1
for dim in range(1, p.ndim):
size = p.shape[dim]
if size == 1 or (size == numel and ignore_rank1_dims):
continue
# Most of the time, (num_blocks, block_size) will equal (1, size)
num_blocks, block_size = self._get_block_size(size, max_block_size)
# Q_{dim} is be the learning-rate matrix for this
# dimension, a matrix of indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# this is used twice in the update step, once transposed and once without transpose.
Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
batch_size, num_blocks, block_size, block_size).contiguous()
state[f"Q_{dim}"] = Q
# cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
# all other dims as a batch axis. Also initialize as identity.
state[f"param_cov_{dim}"] = Q.clone()
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
"""
Returns information about the block size for a block-diagonal structure
of covariances and co-ordinate transformations.
size: The size of the parameter tensor on one of its dimensions, e.g.
a channel dim or kernel size
max_block_size: The maximum block size allowed (of the blocks in a
block-diagonal structure)
Returns:
(num_blocks, block_size)
where nun_blocks * block_size == size and block_size <= max_block_size
"""
for num_blocks in range(1, size):
block_size = size // num_blocks
if block_size * num_blocks == size and block_size <= max_block_size:
return (num_blocks, block_size)
assert False, (size, max_block_size) # code error or e.g. negative or
# zero inputs
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
size_update_period = group["size_update_period"]
grad_cov_period = group["grad_cov_period"]
eps = group["eps"]
beta1 = group["betas"][0]
grad = p.grad
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = _sum(p * grad,
exclude_dims=[0],
keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_(_mean(p ** 2, exclude_dims=[0],
keepdim=True).sqrt().clamp_(min=eps))
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
self._step_scalar(group, p, state)
else:
if self._is_lr_update_step(group, state):
self._update_param_cov(group, p, state)
self._compute_bases(group, p.shape, state)
self._zero_exp_avg_sq(state)
if step % grad_cov_period == 0:
self._update_grad_cov(group, p, state)
self._step(group, p, state)
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
gradient descent on underlying param and on scale, this function does the update
on `scale`.
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
"""
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
size_lr = group["lr"] * group["size_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr ** size_step
# we don't bother with bias_correction1; this will help prevent divergence
# at the start of training.
eps = 1.0e-10 # this value is not critical.
denom = scale_exp_avg_sq.sqrt() + eps # shape: (batch_size, 1, 1, ..)
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1-beta1))
def _update_param_cov(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Updates state[f"param_cov_{dim}"] for each dim of tensor p
(except batch and trivial and rank-1 dims)
"""
eps = group["eps"]
# zero_step is always the last time we called _update_param_cov.
# Our aim is to compute the parameter covariance averaged over all time
# (in steps) from the start, so "this_weight" that we give to this
# new covariance is proportional to the interval of time it represents,
# i.e. the interval from zero_step until step, while the existing "param_cov"
# represents the interval from step until zero_step.
# The min of 0.5 is a special case to ensure that the "torch.eye" we initialized
# param_cov with gets some weight on the 1st time we call this.
zero_step = state["zero_step"]
step = state["step"]
this_weight = min(0.5, (step - zero_step) / step)
batch_size = p.shape[0]
numel = p.numel() // batch_size
for dim in range(1, p.ndim):
size = p.shape[dim]
try:
param_cov = state[f"param_cov_{dim}"]
except KeyError:
continue # e.g. size == 1 or size == numel
# param_cov shape: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
# M: (batch_size, num_blocks, x, block_size)
M = p.transpose(dim, -1).reshape(batch_size, -1,
num_blocks, block_size).transpose(1, 2)
# will be batched matrix multiply. shape of this_param_cov:
# (batch_size, num_blocks, block_size, block_size)
this_param_cov = torch.matmul(M.transpose(2, 3), M)
# normalize scale of this param_cov, in case parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
# shape of expression on r.h.s of "/=" has shape (batch_size, 1, 1, 1)
this_param_cov /= _mean(_diag(this_param_cov),
exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps
param_cov.mul_(1-this_weight).add_(this_param_cov,
alpha=this_weight)
def _zero_exp_avg_sq(self,
state: dict) -> None:
"""
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
"""
state["exp_avg_sq"].zero_()
state["zero_step"] = state["step"]
def _is_lr_update_step(self,
group: dict,
state: dict) -> bool:
"""
Returns True if on this step we need to update the learning-rate matrices for this tensor
and False if not (note: this may just mean we update the gradient-dependent part).
The periodicity with which we update them increases from
(by default) 200 at the start of training to 2000 later on.
"""
try:
zero_step = state["zero_step"]
except:
# This parameter tensor has no learning-rate matrices to estimate
return False
step = state["step"]
period_initial, period_final = group["lr_update_period"]
# this formula gradually increases the periodicity from period_initial at the start
# to period_final when step >> 4 * period_final
cur_update_period = (period_initial +
((period_final - period_initial) * step /
(step + 4 * period_final)))
return (step >= zero_step + cur_update_period)
def _compute_bases(self,
group: dict,
p_shape: torch.Size,
state: dict):
"""
For each tensor dimension that we are preconditioning, this function sets
state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate].
This will be the matrix that diagonalizes Z,
where Z is the symmetric matrix satisfying:
Z G Z = P (eqn:1)
with G being the gradient covariance grad_cov, P being the smoothed version of the parameter
covariance param_cov, and Z the symmetric positive definit learning-rate matrices.
We'll discuss later how we solve for Z. Firstly, because we want
a basis that is as meaningful possible for smoothing P, we will first put P and G in a basis
where G is diagonal. That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the
left and right by U^T and U respetively, and inserting some expressions equivalent to I,
we have:
(U^T Z U) (U^T G U) (U^T Z U) = (U^T P U)
or: Z' G' Z' = P' (eqn:2)
where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal. We are actually going to
be solving (eqn:2), and then computing:
Z = U Z' U^T.
A solution to (eqn:2) is as follows. We are going to be using a Cholesky-based solution in
favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first
have to be careful that the input is not close to singular, though).
So, with the Cholesky decomposition of P' being:
P' = C C^T,
the solution is going to be:
Z' = C (C^T G' C)^{-0.5} C^T (eqn:3)
[[ we can verify that this is a solution by multiplying (eqn:1) by C^{-1} and C^{-T} on the
left and right respectively, giving:
(C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
which we can immediately see is the case. ]]
It's actually going to be more convenient to compute Z'^{1} (the inverse of Z'), because
ultimately only need the basis that diagonalizes Z, and (C^T G' C) may be singular.
It's not hard to work out that:
Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1} (eqn:4)
Args:
group: dict to look up config values
p_shape: the shape of the batch of identical-sized tensors we are optimizing
state: dict to look up optimization state
Return:
This function returns a list of Tensors P_proj indexed by dim from 0..ndim-1, the tensors being
of shape (batch_size, size), which contain the diagonal of the smoothed parameter covariance
P after projection by the orthogonal matrix state[f"Q_{dim}"] that diagonalizes Z (this is
the space in which we do the actual update).
"""
ndim = len(p_shape)
P_proj = [None] * ndim
denom_rel_eps = group["denom_rel_eps"]
eps = group["eps"]
for dim in range(1, ndim):
size = p_shape[dim]
try:
Q = state[f"Q_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel
param_cov = state[f"param_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
grad_cov = state[f"grad_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
# U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape
# is the same as param_cov and grad_cov.
#
# G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
G = grad_cov.clone()
# Use the form of the diagonalized gradient matrix that we get after
# we add the Adam-type smoothing with epsilon.
G_diag = _diag(G) # aliased with G
G_diag += (_mean(G_diag, exclude_dims=[0], keepdim=True) *(denom_rel_eps * denom_rel_eps) +
(eps * eps))
P_unsmoothed = param_cov
P = param_cov.clone()
P = self._smooth_param_cov(group, p_shape, P, G)
if True:
_,S,_ = _svd(P)
logging.info(f"Eigs of P are {S[0][0]}")
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
# C is of shape (batch_size, num_blocks, block_size, block_size).
#def _fake_cholesky(X):
# U_, S_, _ = _svd(X)
# return U_ * S_.sqrt().unsqueeze(-2)
#C = _fake_cholesky(P)
C = P.cholesky()
# A matrix that takes normally distributed data to P would
# be C, because C I C^T = C C^T = P. We can actually use *any* matrix
# that takes normally distributed data to P, so we can use
# C U for any orthogonal U, since C U I U^T C^T == P.
# So there is no harm in choosing a matrix U that diagonalizes the
# projected grad_cov. grad_cov gets projected by
grad_cov_proj = torch.matmul(C.transpose(2, 3),
torch.matmul(grad_cov, C))
# OK, grad_cov is diagonalized by U^T C^T. So the projection that we
# apply to the param cov is C U
U, S, _ = _svd(grad_cov_proj)
# proj is indexed [batch_idx,block_idx,canonical_coordinate,diagonalized_coordinate],
# so we need to transpose to get Q_{dim}.
proj = torch.matmul(C, U)
Q[:] = proj.transpose(2, 3)
continue
def _smooth_param_cov(self,
group: dict,
p_shape: torch.Size,
P: Tensor,
G: Tensor) -> Tensor:
"""
This function returns a modified/smoothed version of the parameter covariance
P.
Args:
group: dict to look up config values
p_shape: The shape of the parameter we are optimizing
P: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
containing the parameter covariance
G: the gradient covariance, of shape (batch_size, num_blocks,
block_size, block_size)
state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
p, averaged over time, and taken over dimension `dim` of the tensor.
The smoothing done here limits the extent to which the parameter covariance
can be strongly "off-diagonal" with respect to the gradient covariance. That is:
if the parameter covariance is just the gradient covariance to some power, this
function does no smoothing; but if it is highly off-diagonal we do more smoothing.
"""
# do smoothing based on 'rank',
# that is intended to compensate for bad estimates of P.
(batch_size, num_blocks, block_size, block_size) = P.shape
# `rank_per_block` is the rank of each block of P_prime if we were to estimate it from just one
# parameter tensor. We average it over time, but actually it won't be changing
# too much, so `rank` does tell us something.
size = num_blocks * block_size
rank = p_shape.numel() // (size * batch_size) # actually the rank of each block
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
# We want expr for smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size)
# for "size" here, we actually want to use block_size, since we are concerned about the
# robustness of the covariance within these blocks.
# param_rms_smooth{0,1} represents the user-specified desired amount of smoothing
# when rank==0*size and rank==1*size, respectively.
# from rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * block_size / ((smooth0/smooth1 - 1) * rank + block_size)
if True:
logging.info(f"block size={block_size}, rank={rank}, smooth={smooth}")
# add rank-dependent smoothing amount to diagonal of P_prime. _diag() returns an aliased tensor.
# we don't need to multiply `smooth` by anything, because at this point, P_prime should have
# diagonal elements close to 1.
P = P.clone()
P_diag = _diag(P)
P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
P_diag += smooth * P_diag_mean
#G = G.clone()
#G_diag = _diag(G) # aliased
#G_diag *= 1.005 # ensure invertible.
G = self._smooth_cov(G,
group["cov_min"][3],
group["cov_max"][3],
group["cov_pow"][3])
P = self._apply_min_max_with_metric(P, G,
group["cov_min"][1],
group["cov_max"][1])
# Apply a 3rd round of smoothing in the canonical basis.
P = self._smooth_cov(P,
group["cov_min"][2],
group["cov_max"][2],
group["cov_pow"][2])
return P
def _smooth_cov(self,
X: Tensor,
min_eig: float,
max_eig: float,
power: float = 1.0) -> Tensor:
"""
Returns a `smoothed` version of a symmetric positive definite covariance matrix
[with block-diagonal structure, in a batch]. This is done without SVD (which
can be very slow).
The eigenvalues L will be transformed as:
L = L + min_eig * L.mean() + eps
L /= L.mean()
L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig
L = L ** power # need SVD for this, will get rid of the requirement later.
L /= L.mean()
# Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
# plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
# [this starts to diverge after 5 or so]
Args:
min_eig: minimum allowed eigenvalue of returned X
max_eig: maximum allowed eigenvalue of returned X
power: power to take eigenvalues to
X: the batch of symmetric positive definite tensors we are smoothing;
of shape (batch_size, num_blocks, block_size, block_size)
"""
eps = 1.0e-20
if power != 1.0:
U, S, _ = _svd(X)
S_mean = _mean(S, exclude_dims=[0], keepdim=True)
S = S + min_eig * S_mean + eps
S_mean = S_mean * (1 + min_eig) + eps
S = S / S_mean
S = 1. / (1./S + 1./max_eig)
S = S ** power
S = S / _mean(S, exclude_dims=[0], keepdim=True)
return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
else:
X = X.clone()
diag = _diag(X) # Aliased with X
mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
diag += (mean_eig * min_eig + eps)
cur_diag_mean = mean_eig * (1 + min_eig) + eps
X /= cur_diag_mean.unsqueeze(-1)
# OK, now the mean of the diagonal of X is 1 (or less than 1, in
# which case X is extremely tiny).
# eig_ceil is the maximum possible eigenvalue that X could possibly
# have at this time, equal to num_blocks * block_size.
eig_ceil = X.shape[1] * X.shape[3]
# the next statement wslightly adjusts the target to be the same as
# what the baseline function, eig -> 1./(1./eig + 1./max_eig) would
# give. so "max_eig" is now the target of the function for arg ==
# eig_ceil.
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
while max_eig <= eig_ceil * 0.5:
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
# on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
# increasing from 0..eig_ceil.
# the transpose on the 2nd X is to try to stop small asymmetries from
# propagating.
X = X - 0.5/eig_ceil * torch.matmul(X, X.transpose(2, 3))
eig_ceil = 0.5 * eig_ceil
# max_eig > eig_ceil * 0.5
if max_eig < eig_ceil:
# map l -> l - coeff*l*l, if l==eig_ceil, this
# takes us to:
# eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
# == max_eig
# .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5]
# means that the function is monotonic on inputs from 0 to eig_ceil.
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
X = X - coeff * torch.matmul(X, X.transpose(2, 3))
# Normalize again.
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
X = 0.5 * (X + X.transpose(2, 3)) # make sure exactly symmetric.
return X
def _apply_min_max_with_metric(self,
X: Tensor,
M: Tensor,
min_eig: float,
max_eig: float) -> Tensor:
"""
Smooths X with maximum eigenvalue (relative to the mean) relative to
metric M. Equivalent to applying
Y := M^{0.5} X M^{0.5}
Apply maximum to eigenvalues of Y, as done in _smooth_cov()
X := M^{-0.5} Y M^{-0.5}
Args:
X: Batch of block-diagonal nonzero positive definite matrices to have max eigenvalue applied
M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
"""
X = X.clone()
# size of the block-diagonal matrix..
size = X.shape[1] * X.shape[3]
# mean eig of M^{0.5} X M^{0.5} ...
mean_eig = _sum(X*M, exclude_dims=[0], keepdim=True) / size
# make sure eigs of M^{0.5} X M^{0.5} are average 1. this imposes limit on the max.
X /= mean_eig
if min_eig != 0.0:
X = X * (1.0-min_eig) + min_eig * M.inverse()
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
# eig_ceil is the maximum possible eigenvalue that X could possibly
# have at this time, equal to num_blocks * block_size.
eig_ceil = size
# the next statement wslightly adjusts the target to be the same as
# what the baseline function, eig -> 1./(1./eig + 1./max_eig) would
# give. so "max_eig" is now the target of the function for arg ==
# eig_ceil.
max_eig = 1. / (1. / max_eig + 1. / eig_ceil)
while max_eig <= eig_ceil * 0.5:
# this is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
# on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
# increasing from 0..eig_ceil.
# The transpose is to try to stop small asymmetries from increasing.
X = X - 0.5/eig_ceil * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
eig_ceil = 0.5 * eig_ceil
# max_eig > eig_ceil * 0.5
if max_eig < eig_ceil:
# map l -> l - coeff*l*l, if l==eig_ceil, this
# takes us to:
# eig_ceil - (eig_ceil-max_eig/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
# == max_eig
# .. and the fact that coeff <= 0.5/eig_ceil [since max_eig>eig_ceil*0.5]
# means that the function is monotonic on inputs from 0 to eig_ceil.
coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
# Normalize again.
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
X = 0.5 * (X + X.transpose(-2, -1)) # make sure exactly symmetric.
return X
def _update_grad_cov(self,
group: dict,
p: Tensor,
state: dict):
"""
Updates the gradient covariance state[f"grad_cov_{dim}"] for appropriate
dimensions of the tensor.
"""
beta2 = group["betas"][1]
grad = p.grad
batch_size = p.shape[0]
for dim in range(1, grad.ndim): # dim 0 is batch dim.
size = grad.shape[dim]
name = f"grad_cov_{dim}"
if name not in state:
continue # e.g. size==1, size == numel, size > max_block_size
grad_cov = state[name] # shaped: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = grad_cov.shape
g = grad.transpose(-1, dim)
# if grad were of shape (batch_size, x, size, y, z),
# g would be of shape (batch_size, x, z, y, size); and
# after the next line, g will be of shape
# (batch_size, x, z, y, num_blocks, block_size)
g = g.reshape(*g.shape[:-1], num_blocks, block_size)
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, z, y, block_size)
g = g.reshape(batch_size, num_blocks, -1, block_size)
# now g is of shape (batch_size, num_blocks, x*z*y, block_size)
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))
def _step(self,
group: dict,
p: Tensor,
state: dict):
"""
This function does the core update of self.step(), in the case where the tensor
has more than 1 elements. It multiplies the moving-average gradient tensor by a
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
and takes a step in that direction.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
batch: if true, the first dimension of p is a batch dim, i.e. a batch of
parameter matrices.
This function modifies p.
"""
grad = p.grad
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
denom_rel_eps = group["denom_rel_eps"]
step = state["step"]
grad = self._project(grad, state, forward=True)
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt()
denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True)
grad = grad / denom
# and project back..
grad = self._project(grad, state, forward=False)
alpha = -lr * (1-beta1) * state["param_rms"]
delta = state["delta"]
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self,
group: dict,
p: Tensor,
state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms. p will not be a scalar, because it has a batch dim.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr*(1-beta1))
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)
def _project(self,
M: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply a tensor M by a projection Q on each of its dimensions (for which we have
such a projection)
Args:
M: The tensor to project, e.g. a gradient or a parameter change.
state: dict to look up the projections
forward: if True, go in the forward direction (from canonical to diagnonalized
co-ordinates); if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
"""
numel = M.numel()
for dim in range(1, M.ndim): # dim 0 is batch dim (batch of parameter
# tensors of same shape)
try:
Q = state[f"Q_{dim}"] # shape: (batch_size, num_blocks, block_size, block_size)
except KeyError:
continue # no projection for this dim
(batch_size, num_blocks, block_size, block_size) = Q.shape
size = M.shape[dim] # == num_blocks * block_size
if forward:
# Q is indexed
# [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
# the forward direction we want to change canonical to
# diagonalized index, so we transpose.
Q = Q.transpose(2, 3)
# assume M currently has shape (batch_size, x, size, y, z), and dim==2.
M = M.transpose(-1, dim) # (batch_size, x, y, z, size)
M = M.reshape(batch_size, *M.shape[1:-1],
num_blocks, block_size) # (batch_size, x, y, z, num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
if M.ndim < Q.ndim: # special case where M shape is e.g. (batch_size, num_blocks, x)
M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
else:
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
M = M.transpose(-1, dim) # (batch_size, x, size, y, z)
return M
def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
"""
Moves the position of one dimension in a Tensor, while keeping
the remaining dims in the same order. E.g. if x has
shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
[10, 3, 1, 2]. Returns the reshaped tensor which will be
a view of the original Tensor.
Args:
x: Tensor to reshape
orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
new_dim: New position of the original dimension, with
-x.ndim <= new_dim < x.ndim
Returns: a permuted view of x
"""
if orig_dim < 0:
orig_dim += x.ndim
if new_dim < 0:
new_dim += x.ndim
dims = list(range(x.ndim))
dims[orig_dim] = -1
dims.remove(-1)
dims.insert(new_dim, orig_dim)
return x.permute(dims)
def _diag(x: Tensor):
"""
like torch diag(), this returns the diagonal of a matrix; but it returns an
aliased tensor, and it supports batch dims,
i.e. input of shape (B, M, M) returns
output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
(A, B, M).
"""
stride = x.stride()
if x.ndim == 3:
(B, M, M2) = x.shape
assert M == M2
ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
elif x.ndim == 4:
(B, C, M, M2) = x.shape
assert M == M2
ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3]))
elif x.ndim == 2:
(M, M2) = x.shape
assert M == M2
ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
return ans
def _sum(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.sum where you specify dims to exclude, not dims to sum.
"""
if len(exclude_dims) == 0:
# dim=[] works the same as tuple(range(x.ndim)), which makes
# no sense.. supplying the full tuple for clarity and in case
# the interface changes in future.
return x.sum(dim=tuple(range(x.ndim)), keepdim=keepdim)
elif x.ndim == 1:
assert exclude_dim == [0] or exclude_dim == [-1]
return x # if one dim is excluded, there are no dims to sum, and
# x.sum(dim=[]) sums all dims so we should not call sum().
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
return x.sum(dim=dims, keepdim=keepdim)
def _mean(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.mean where you specify dims to exclude, not dims to mean.
"""
if len(exclude_dims) == 0:
# dim=[] works the same as tuple(range(x.ndim)), which makes
# no sense.. supplying the full tuple for clarity and in case
# the interface changes in future.
return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
elif x.ndim == 1:
assert exclude_dims == [0] or exclude_dims == [-1]
return x # if one dim is excluded, there are no dims to mean, and
# x.mean(dim=[]) means all dims so we should not call mean().
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
return x.mean(dim=dims, keepdim=keepdim)
def _svd(x: Tensor, recursion_depth: int = 0):
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
# some versions of torch SVD for CUDA)
xsum = x.sum()
if not (xsum - xsum == 0):
raise RuntimeError("svd on inf or nan input")
try:
U, S, V = x.svd()
s = U.sum()
if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
raise RuntimeError(f"svd failed (or random test), sum={s}")
return U, S, V # success
except:
if recursion_depth < 2:
logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
f"error in SVD. Will try again after random permutation of rows")
# randomly permute the row indexes of the matrices, and retry, hoping this
# fixes the issue
n = x.shape[-2]
randperm = torch.randperm(n, device=x.device)
inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
inv_randperm[randperm] = torch.arange(n, device=x.device)
x = torch.index_select(x, dim=-2, index=randperm)
U,S,V = _svd(x, recursion_depth + 1)
return torch.index_select(U, dim=-2, index=inv_randperm), S, V
elif recursion_depth < 4:
logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, x.sum()={x.sum()}. Will try orthogonal transformation")
Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
dtype=x.dtype).svd()
U, S, V = _svd(torch.matmul(Urand, x),
recursion_depth + 1)
return torch.matmul(Urand.t(), U), S, V
else:
raise RuntimeError(f"svd failed after {recursion_depth} tries")
class Cain(Optimizer):
"""Implements Cain algorithm, which is a substantially modified form of the
Adam optimizer.
The modifications can be summarized as follows:
(i) change Adam to what we called "Eve", which adds an extra configuration
value, "target_rms" (default value: 0.1); and for non-scalar parameters,
if their root-mean-square value exceeds "target_rms", we apply weight
decay (aggressive enough weight decay that the rms stays at target_rms).
This weight decay is applied in the same way as in AdamW, i.e. by
parameter shrinkage rather than by modifying the gradients.
(ii) By limiting the rms to 0.1, we reduce our modeling power; we restore
it by replacing all weights and biases-- in fact, all model parameters--
with expressions like
weight * weight_scale.exp() or bias * bias_scale.exp().
This has the benefit of making the learnable offsets and scales of
BatchNorm/LayerNorm unnecessary.
Below we describe the changes from Eve to Cain:
(iii) We removed all the weight_scale and bias_scale parameters from
the model and "moved them into the optimizer". Now we allow
parameters to have any rms value (that's <= 10), but we:
(a) make the update speed for a parameter proportional to
its rms value.
(b) Multiply the learning rate by 10 because we are getting rid of
target_rms=0.1
(c) simulate the effect of learning the weight_scale or bias_scale
in log space using standard Adam, with a scale to prevent it
from learning too fast. (see scale_exp_avg_sq in the
code below).
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Eve is unpublished so far.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
scale_speed (float, optional): we multiply the learning rate by this value
when learning (in log space) the scales of the parameter tensors
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps. In practice, rms values
will not go much below this.
param_max (float, optional): maximum absolute value for any parameter;
we will clamp parameters to satisfy -param_max <= param <= param_max
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-2,
betas=(0.9, 0.98),
eps=1e-8,
scale_speed=0.05,
rms_eps=1.0e-05,
param_max=20.0,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 < scale_speed < 1.0:
raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
if not 0.1 < param_max <= 10000.0:
raise ValueError("Invalid param_max value: {}".format(param_max))
if not 0.0 < rms_eps < 0.1:
raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
scale_speed=0.05,
rms_eps=rms_eps,
param_max=param_max,
)
super(Cain, self).__init__(params, defaults)
def __setstate__(self, state):
super(Cain, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
scale_speed = group["scale_speed"]
rms_eps = group["rms_eps"]
eps = group["eps"]
param_max = group["param_max"]
for p in group["params"]:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"Cain does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.numel() > 1:
# we keep track of the RMS value of the parameters to
# help set the learning rate... this emulates Eve.
state["param_rms"] = torch.ones_like(p)
# we learn the scale on the parameters in a way that
# also emulates Eve. scale_exp_avg_sq is the
# moving-average square of the gradient w.r.t. this
# scalar scale.
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
dtype=p.dtype)
delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
step = state["step"]
delta.mul_(beta1)
if p.numel() > 1:
if True:
# This block learns the scale on the parameters.
# scale_deriv is the derivative w.r.t. a log-space scale
# on the parameters, i.e. if we imagine: p =
# underlying_param * scale.exp(), scale_deriv is the
# derivative w.r.t. scale (it does not matter
# what value `scale` currently has).
scale_deriv = (p * grad).sum()
scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
value=1 - beta2)
# should actually be step + 1, so on 1st minibatch we are not learning
# anything here. May fix this at some point.
scale_bias_correction2 = 1 - beta2 ** (step + 1)
scale_denom = (scale_exp_avg_sq.sqrt()).add_(eps)
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
# p = underlying_param * scale.exp(),
# delta is the change in `scale`.
scale_delta = scale_alpha * (scale_deriv / scale_denom)
delta.add_(p, alpha=scale_delta)
# Decay the second moment running average coefficient
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** step
grad_rms = (exp_avg_sq.sqrt()).add_(eps)
this_delta = grad / grad_rms
param_rms = state["param_rms"]
self._update_param_rms(p, param_rms, step, rms_eps)
alpha = (-(1-beta1) * lr * bias_correction2)
# geometrically interpolate the per-element param_rms with its mean
#param_rms_half = (param_rms * param_rms.mean()) ** 0.5
delta.add_(this_delta * param_rms, alpha=alpha)
# below, will do: delta.add_(this_delta, alpha=alpha)
else:
# p.numel() == 1
# Decay the second moment running average coefficient
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** step
denom = (exp_avg_sq.sqrt()).add_(eps)
this_delta = grad / denom
alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
delta.add_(this_delta, alpha=alpha)
p.add_(delta)
if step % 10 == 0:
p.clamp_(min=-param_max, max=param_max)
state["step"] += 1
return loss
def _update_param_rms(self, p: Tensor, param_rms: Tensor,
step: int, rms_eps: float):
"""
Updates `param_rms`. Require p.numel() > 1
Args:
p: parameter whose magnitude/rms we are measuring
param_rms: a Tensor containing non-negative values, with the same
shape as p. Contains an estimate of the root-mean-square
value of the parameter. This has a special structure where
it must be a product of one-dimensional quantities,
one for each axis of p.
We exclude certain axes to avoid estimating the rms value
from a too-small number of parameters.
step: the step of the algorithm
rms_eps: epsilon such that we'll aim to prevent param_rms from
going much lower than this.
"""
if step % 1000 == 0:
# Periodically refresh param_rms to keep the invariance that it is a
# product over its axes (numerical roundoff might destroy this)
param_rms.fill_(1.0)
for dim in range(p.ndim):
self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)
else:
dim = step % p.ndim
self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)
def _update_param_rms_for_dim(self, p: Tensor,
param_rms: Tensor,
dim: int,
rms_eps: float):
"""
Updates `param_rms` on one of its axes/dimensions.
Args:
p: parameter whose magnitude/rms we are measuring
param_rms: a Tensor containing non-negative values, with the same
shape as p. Contains an estimate of the root-mean-square
value of the parameter. This has a special structure where
it must be a product of one-dimensional quantities,
one for each axis of p.
We exclude certain axes to avoid estimating the rms value
from a too-small number of parameters.
dim: with 0 <= dim < p.ndim, the
dimension/axis on which to update the scale factor
rms_eps: epsilon such that we'll aim to prevent param_rms from
going much lower than this.
"""
# var_factor is the square of the factor to change param_rms by. p will
# not be super large, we limit it to -param_max <= p <= param_max.
# Clamping p**2 to min=1.0e-10 should prevent param_rms from getting much
# smaller than 1.0e-05, which should prevent division by zero here.
var_factor = (p ** 2).clamp(min=rms_eps*rms_eps) / (param_rms ** 2)
if p.numel() <= 2 * p.shape[dim]:
var_factor = var_factor.mean()
else:
dims = tuple(d for d in range(p.ndim) if d != dim)
var_factor = var_factor.mean(dim=dims, keepdim=True)
#if random.random() < 0.01:
# logging.info(f"shape={p.shape}, var_factor={var_factor}")
param_rms.mul_(var_factor.sqrt())
class LRScheduler(object):
"""
Base-class for learning rate schedulers where the learning-rate depends on both the
batch and the epoch.
"""
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
self.optimizer = optimizer
self.verbose = verbose
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
self.base_lrs = [
group["initial_lr"] for group in optimizer.param_groups
]
self.epoch = 0
self.batch = 0
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {
"base_lrs": self.base_lrs,
"epoch": self.epoch,
"batch": self.batch,
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_last_lr(self) -> List[float]:
"""Return last computed learning rate by current scheduler. Will be a list of float."""
return self._last_lr
def get_lr(self):
# Compute list of learning rates from self.epoch and self.batch and
# self.base_lrs; this must be overloaded by the user.
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
raise NotImplementedError
def step_batch(self, batch: Optional[int] = None) -> None:
# Step the batch index, or just set it. If `batch` is specified, it
# must be the batch index from the start of training, i.e. summed over
# all epochs.
# You can call this in any order; if you don't provide 'batch', it should
# of course be called once per batch.
if batch is not None:
self.batch = batch
else:
self.batch = self.batch + 1
self._set_lrs()
def step_epoch(self, epoch: Optional[int] = None):
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
# you should call this at the start of the epoch; if you don't provide the 'epoch'
# arg, you should call it at the end of the epoch.
if epoch is not None:
self.epoch = epoch
else:
self.epoch = self.epoch + 1
self._set_lrs()
def _set_lrs(self):
values = self.get_lr()
assert len(values) == len(self.optimizer.param_groups)
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group["lr"] = lr
self.print_lr(self.verbose, i, lr)
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logging.info(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
class Eve(Optimizer):
"""
Implements Eve algorithm. This is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular target_rms (default: 0.1). This is
for use with networks with 'scaled' versions of modules (see scaling.py), which
will be close to invariant to the absolute scale on the parameter matrix.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Eve is unpublished so far.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
this value means that the weight would decay significantly after
about 3k minibatches. Is not multiplied by learning rate, but
is conditional on RMS-value of parameter being > target_rms.
target_rms (float, optional): target root-mean-square value of
parameters, if they fall below this we will stop applying weight decay.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-8,
weight_decay=1e-3,
target_rms=0.1,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
target_rms=target_rms,
)
super(Eve, self).__init__(params, defaults)
def __setstate__(self, state):
super(Eve, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"AdamW does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
group["eps"]
)
step_size = group["lr"] / bias_correction1
target_rms = group["target_rms"]
weight_decay = group["weight_decay"]
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
if random.random() < 0.0005:
step = (exp_avg/denom) * step_size
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
return loss
class Eden(LRScheduler):
"""
Eden scheduler.
lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
E.g. suggest initial-lr = 0.003 (passed to optimizer).
Args:
optimizer: the optimizer to change the learning rates on
lr_batches: the number of batches after which we start significantly
decreasing the learning rate, suggest 5000.
lr_epochs: the number of epochs after which we start significantly
decreasing the learning rate, suggest 6 if you plan to do e.g.
20 to 40 epochs, but may need smaller number if dataset is huge
and you will do few epochs.
"""
def __init__(
self,
optimizer: Optimizer,
lr_batches: Union[int, float],
lr_epochs: Union[int, float],
verbose: bool = False,
):
super(Eden, self).__init__(optimizer, verbose)
self.lr_batches = lr_batches
self.lr_epochs = lr_epochs
def get_lr(self):
factor = (
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
) ** -0.25 * (
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
** -0.25
)
return [x * factor for x in self.base_lrs]
def _test_eden():
m = torch.nn.Linear(100, 100)
optim = Cain(m.parameters(), lr=0.003)
scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
for epoch in range(10):
scheduler.step_epoch(epoch) # sets epoch to `epoch`
for step in range(20):
x = torch.randn(200, 100).detach()
x.requires_grad = True
y = m(x)
dy = torch.randn(200, 100).detach()
f = (y * dy).sum()
f.backward()
optim.step()
scheduler.step_batch()
optim.zero_grad()
logging.info(f"last lr = {scheduler.get_last_lr()}")
logging.info(f"state dict = {scheduler.state_dict()}")
def _test_eve_cain():
import timeit
from scaling import ScaledLinear
E = 100
B = 4
T = 2
logging.info("in test_eve_cain")
#device = torch.device('cuda')
device = torch.device('cpu')
dtype = torch.float32
fix_random_seed(42)
# these input_magnitudes and output_magnitudes are to test that
# Abel is working as we expect and is able to adjust scales of
# different dims differently.
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
#for iter in [3, 2, 1, 0]: # will restore 1,0 later
for iter in [3, 2]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
hidden_dim = 300
m = torch.nn.Sequential(Linear(E, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, E),
).to(device)
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()
avg_loss = 0.0
for epoch in range(180):
scheduler.step_epoch()
#if epoch == 100 and iter in [2,3]:
# optim.reset_speedup() # check it doesn't crash.
#if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y)**2).mean() * 100.0
if epoch == 0 and n == 0:
avg_loss = loss.item()
else:
avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
if n == 0 and epoch % 5 == 0:
#norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
#norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
#norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
#norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
#scale1 = '%.2e' % (m[0].weight_scale.exp().item())
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0]
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
scheduler.step_batch()
#diagnostic.print_diagnostics()
stop = timeit.default_timer()
logging.info(f"Iter={iter}, Time taken: {stop - start}")
logging.info(f"last lr = {scheduler.get_last_lr()}")
#logging.info("state dict = ", scheduler.state_dict())
#logging.info("optim state_dict = ", optim.state_dict())
logging.info(f"input_magnitudes = {input_magnitudes}")
logging.info(f"output_magnitudes = {output_magnitudes}")
def _test_svd():
device = 'cuda'
for i in range(1000):
X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device)
U, S, V = _svd(X)
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
def _test_smooth_cov():
b = (torch.arange(10)*0.5).exp().diag()
b = b.unsqueeze(0).unsqueeze(0)
c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
logging.info(f"c[noinv] = {_diag(c)}")
c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
logging.info(f"c[svd] = {_diag(c)}")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO)
import subprocess
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
_test_smooth_cov()
logging.info(s)
#_test_svd()
_test_eve_cain()
#_test_eden()