1840 lines
77 KiB
Python

# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import List, Optional, Union, Tuple, List
from lhotse.utils import fix_random_seed
import torch
from scaling import ActivationBalancer
import random
from torch import Tensor
from torch.optim import Optimizer
import logging
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension.
Args:
params:
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
def batched_params(self, param_group):
"""
This function returns an iterator over tuples (p, state), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this fake parameter
(it will be physically located in the "state" for one of the real
parameters, the last one that has any particular shape and dtype).
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
for p, state in self.batched_params(group["params"]):
...
</code>
Args:
group: a parameter group, which is a list of parameters; should be
one of self.groups.
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
for p in param_group:
assert p.grad is not None and "All parameters must have a grad when you call optimizer.step()"
assert not p.grad.is_sparse and "Sparse gradients not supported."
key = (str(p.dtype), *p.shape)
batches[key].append(p)
for p in param_group:
key = (str(p.dtype), *p.shape)
batch = batches[key]
if p is batch[0]: # if this is the 1st param in the batch...
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([p.grad for p in batch])
p_stacked.grad = grad
yield p_stacked, state # <-- calling code will do the actual optimization here!
# Now un-stack the parameter changes
for i,p in enumerate(batch):
p.copy_(p_stacked[i])
class PrAdam(BatchedOptimizer):
"""
Implements 'Projected Adam', a variant of Adam with projections for each
dim of each tensor.
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
scalar tensors. Must satisfy 0 < beta <= beta2 < 1.
size_lr_scale: A scaling factor on the learning rate, that we use to update the
scale of each parameter tensor. If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
is the scaling factor on the learning rate of p_scale.
param_pow: Power on the parameter covariance matrix, 1.0 means learn proportional
to parameter rms (1.0 will be too much, should be between 0 and 1.)
param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
assumed rank of param covariance [==product of sizes on the other
tensor dims] approaches 0.
param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param covariance equals the dimension of the covaraince matrix.
param_rms_smooth{0,1} determine the smoothing proportions for other
conditions.
param_cov_freshness: Constant that determines how "recent" the parameter covariance
matrix is. 1.0 corresponds to flat average over time; 2.0
would mean weighting with a quadratic function, etc.
eps: An epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it >= this size)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it <= this size)
scalar_max: Maximum absolute value for scalar parameters
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time.
lr_update_period: Determines the periodicity, in steps, with which we update the
learning-rate matrices. The first number is the periodicity at
the start of training, the second number is the periodicity
later in training. One step of updating the learning rate matrices
can take as long as over 50 minibatches, because SVD on GPU is slow.
** This is important for the speed/optimizaton tradeoff. **
param_cov_period: The periodicity, in steps, with which we update the parameter covariance
stats.
max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
"""
def __init__(
self,
params,
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
param_pow=0.5,
param_rms_smooth0=0.75,
param_rms_smooth1=0.25,
param_cov_freshness=1.0,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=2.0,
scalar_max=2.0,
size_update_period=4,
lr_update_period=(200, 2000),
grad_cov_period=3,
param_cov_period=100,
max_block_size=1024,
):
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
param_pow=param_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
param_cov_freshness=param_cov_freshness,
betas=betas,
eps=eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
lr_update_period=lr_update_period,
grad_cov_period=grad_cov_period,
param_cov_period=param_cov_period,
max_block_size=max_block_size,
)
super(PrAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PrAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
batch = True
for group in self.param_groups:
for p, state in self.batched_params(group["params"]):
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"PrAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
self._step_one_batch(group, p, state)
return loss
def _init_state(self,
group: dict,
p: Tensor,
state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for, actually
will be stacked-together "real" parameters
state: Dict from string to whatever state we are initializing
"""
eps = group["eps"]
size_update_period = group["size_update_period"]
max_block_size = group["max_block_size"]
state["step"] = 0
kwargs = {'device':p.device, 'dtype':p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
# exp_avg_sq is in the projected space. It is periodically zeroed, and we
# set zero_step whenever this happens.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
ignore_rank1_dims = True # a config value, can change this (TODO: tune)
trivial_update = True
for dim in range(1, p.ndim):
size = p.shape[dim]
if size == 1 or (size == numel and ignore_rank1_dims):
continue
trivial_update = False
state["trivial_update"] = trivial_update # so we know whether to zero exp_avg_sq stats.
if trivial_update:
return
# "zero_step" being a member of state is the sign that this parameter has
# at least one dim that has a projection.
state["zero_step"] = 0
for dim in range(1, p.ndim):
size = p.shape[dim]
if size == 1 or (size == numel and ignore_rank1_dims):
continue
# Most of the time, (num_blocks, block_size) will equal (1, size)
num_blocks, block_size = self._get_block_size(size, max_block_size)
# Q_{dim} is be the learning-rate matrix for this
# dimension, a matrix of indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# this is used twice in the update step, once transposed and once without transpose.
Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
batch_size, num_blocks, block_size, block_size).contiguous()
state[f"Q_{dim}"] = Q
# param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros_like(Q)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as as a batch axis.
# This is needed when we re-diagonalize, and to compute the
# grad_rms_{dim}. We store it as a decaying average, decaying with beta2
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"Q_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
"""
Returns information about the block size for a block-diagonal structure
of covariances and co-ordinate transformations.
size: The size of the parameter tensor on one of its dimensions, e.g.
a channel dim or kernel size
max_block_size: The maximum block size allowed (of the blocks in a
block-diagonal structure)
Returns:
(num_blocks, block_size)
where nun_blocks * block_size == size and block_size <= max_block_size
"""
for num_blocks in range(1, size):
block_size = size // num_blocks
if block_size * num_blocks == size and block_size <= max_block_size:
return (num_blocks, block_size)
assert False, (size, max_block_size) # code error or e.g. negative or
# zero inputs
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
size_update_period = group["size_update_period"]
grad_cov_period = group["grad_cov_period"]
param_cov_period = group["param_cov_period"]
eps = group["eps"]
beta1 = group["betas"][0]
grad = p.grad
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = _sum(p * grad,
exclude_dims=[0],
keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_(_mean(p ** 2, exclude_dims=[0],
keepdim=True).sqrt().clamp_(min=eps))
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
self._step_scalar(group, p, state)
else:
if step % param_cov_period == 0:
self._update_param_cov(group, p, state)
if self._is_lr_update_step(group, state):
self._update_lrs(group, p, state)
self._zero_exp_avg_sq(state)
if step % grad_cov_period == 0:
self._update_grad_cov(group, p, state)
self._step(group, p, state)
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
gradient descent on underlying param and on scale, this function does the update
on `scale`.
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
"""
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
size_lr = group["lr"] * group["size_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr ** size_step
# we don't bother with bias_correction1; this will help prevent divergence
# at the start of training.
eps = 1.0e-10 # this value is not critical.
denom = scale_exp_avg_sq.sqrt() + eps # shape: (batch_size, 1, 1, ..)
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1-beta1))
def _update_param_cov(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Updates state[f"param_cov_{dim}"] for each dim of tensor p
(except batch and trivial and rank-1 dims)
"""
eps = group["eps"]
param_cov_period = group["param_cov_period"]
step = state["step"]
this_weight = (group["param_cov_freshness"] * param_cov_period /
(step + param_cov_period))
batch_size = p.shape[0]
numel = p.numel() // batch_size
for dim in range(1, p.ndim):
size = p.shape[dim]
try:
param_cov = state[f"param_cov_{dim}"]
except KeyError:
continue # e.g. size == 1 or size == numel
# param_cov shape: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
# M: (batch_size, num_blocks, x, block_size)
M = p.transpose(dim, -1).reshape(batch_size, -1,
num_blocks, block_size).transpose(1, 2)
# will be batched matrix multiply. shape of this_param_cov:
# (batch_size, num_blocks, block_size, block_size)
this_param_cov = torch.matmul(M.transpose(2, 3), M)
# normalize scale of this param_cov, in case parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
# shape of this_param_cov
# expression after /= has shape (batch_size, num_blocks, 1, 1)
this_param_cov /= _diag(this_param_cov).mean(dim=[0,1], keepdim=True).unsqueeze(-1) + eps
param_cov.mul_(1-this_weight).add_(this_param_cov,
alpha=this_weight)
def _zero_exp_avg_sq(self,
state: dict) -> None:
"""
Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
"""
state["exp_avg_sq"].zero_()
state["zero_step"] = state["step"]
def _is_lr_update_step(self,
group: dict,
state: dict) -> bool:
"""
Returns True if on this step we need to update the learning-rate matrices for this tensor
and False if not. The periodicity with which we update them increases from
(by default) 200 at the start of training to 2000 later on.
"""
try:
zero_step = state["zero_step"]
except:
# This parameter tensor has no learning-rate matrices to estimate
return False
step = state["step"]
period_initial, period_final = group["lr_update_period"]
# this formula gradually increases the periodicity from period_initial at the start
# to period_final when step >> 4 * period_final
cur_update_period = (period_initial +
((period_final - period_initial) * step /
(step + 4 * period_final)))
return (step >= zero_step + cur_update_period)
def _update_lrs(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Computes learning-rate matrices Q for each dim of this tensor.
Args:
group: dict to look up configuration values
p: parameter matrix that we are updating. The learning rate matrices
are actually factors of p, so p itself will change when we change
them.
state: state dict for the current parameter
"""
ndim = p.ndim
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel in p[0].shape:
return # Nothing to do for this parameter matrix. E.g. a bias or a scalar.
scale_arr = [None] * ndim
# the small random part is to ensure there are no exact zeros, e.g.
# if we initialized some parameters to zero.
#
# we are going to make "cur_p" a corrected version of the parameter
# covariance that is projected with the orthogonal projections U, and that,
# on the axes given by these projections, has covariance proportional
# to that of state[f"param_cov_{dim}"] on those same axes. So it's
# *slightly* whitened, just to match the stats, but not completely
# whitened.
eps = 1.0e-20
cur_p = p + eps * torch.randn_like(p)
for dim in range(1, ndim):
size = p.shape[dim]
try:
Q = state[f"Q_{dim}"]
param_cov = state[f"param_cov_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel:
# param_cov has the same shape as Q
(batch_size, num_blocks, block_size, block_size) = Q.shape
# U,S,V have extra leading dimensions, i.e. of shape
# {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
# S.shape == (batch_size, num_blocks, block_size).
U, S, V = _svd(param_cov)
Q[:] = U.transpose(2, 3)
M = cur_p.transpose(dim, -1)
# if p were of shape (batch_size, x, size, y, z),
# after the next line M would be of shape
# (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(batch_size, *M.shape[1:-1],
num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
while U.ndim < M.ndim:
U = U.unsqueeze(2)
# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, y, z, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
# cur_param_var is a diagonal parameter variance over dimension `dim`,
# of the current "slightly-whitened" parameter; it
# will have shape something like [1, size, 1]; or [batch_size, 1, size, 1].
cur_param_var = _mean(cur_p**2,
exclude_dims=[0,dim],
keepdim=True) # (batch_size, 1, size, 1, 1) if dim==2
S = S.reshape(cur_param_var.shape) # (batch_size, 1, size, 1, 1) if dim==2
# OK, cur_param_var would have the values as S if the variance stats
# param_cov_{dim} were accumulated from this exact parameter matrix,
# but actually they contain older versions of the parameter
# covariance so they will, in general, be less extreme ("whiter
# spectrum"). We scale p so that it matches the accumulated stats,
# the idea is to ensure it doesn't have any too-small eigenvalues
# (where the stats permit).
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
if random.random() < 0.01:
skip = 10 if size < 20 else 1
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
# scale shape: (batch_size, 1, size, 1, 1)
cur_p *= scale
# OK, at this point we have a matrix cur_p that is (somewhat)
# diagonalized by orthogonal matrices in each non-trivial dim, that is
# also multiplied by scalars to match the accumulated covariance stats,
# i.e. "slightly whitened" (in general this will make a modest
# difference, making the eigenvalue distribution a bit flatter). Now we
# will work out the "scaling" part of the learning-rate matrices Q. We
# can't do this independenty for each dim, because there is a risk of
# "counting things twice". E.g. for a matrix with 2 dims, if we do SVD
# M = U S V^T, if we consider the covariance on the left and the right,
# S will be reflected in both covariances, so it doesn't make sense to
# correct for S twice. Not all parameter matrices have exactly 2 dims,
# and also we're dealing with accumulated parameter stats which makes
# things not quite so simple, so we don't want to just take the sqrt of
# S.
# cur_scales[dim] will be a 1-d tensor with shapes like (batch_size, 1, 1, size, 1),
# containing the scales on the learning-rate matrix for this dimension.
# we apply these scales to the parameter matrix before estimating the
# cur_scales for the other dims.
cur_scales = [None] * ndim
debug = (random.random() < 0.001)
for i in range(4): # for 4 iterations (this is quite arbitrary)
for dim in range(1, ndim):
size = p.shape[dim]
if not f"Q_{dim}" in state:
assert size == 1 or size == numel, size
continue
if cur_scales[dim] is not None:
# correct for the fact that we have already normalized this dim in cur_p,
# i.e. remove the scaling for *this* dim, while keeping the scaling for
# all other dims.
cur_p *= cur_scales[dim] # cur_scales shape: (batch_size, 1, size, 1, 1)
# rms shape: (batch_size, 1, size, 1, 1)
rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
rank = numel // size
smoothed_rms = self._smooth_param_rms(group, rms, rank)
cur_scales[dim] = smoothed_rms
cur_p /= smoothed_rms # normalize/"whiten" cur_p on this dim..
if debug:
def _summarize(rms):
rms = rms[0] # get rid of batch dim by selecting one example
rms = rms.flatten()
return rms[::10] # subset one every ten items
logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}, smoothed_rms={_summarize(smoothed_rms)}")
# Apply the scales in `cur_scales` to Q for each dim; this reflects the
# parameter rms values in the parameter-diagonalized space, that we have
# estimated in the loop above.
#
# We normalize the scales in such a way that the Frobenius norm
# after projecting (grad / denom) with Q should be unchanged, i.e. the
# same as (grad / denom), which is equivalent to having rms=1.0 due
# to how denom is constructed. This simplifies the normalization of the overall
# scale of the parameter change: we just have to multiply by the learning
# rate and param_rms.
for dim in range(1, ndim):
if cur_scales[dim] is not None:
size = p.shape[dim]
Q = state[f"Q_{dim}"]
(batch_size, num_blocks, block_size, block_size) = Q.shape
scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
# The following normalization step will ensure the Frobenius
# norm is unchanged, from applying this scale: at least,
# assuming "grad / denom" gives uncorrelated outputs so that
# they will have equal variances after projecting to the space
# where the parameter var is diagonalized... this is *roughly*
# true because the gradients at the point where we compute "grad
# / denom" should be decorrelated at least considering
# individual tensor dims
scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
# Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
# want to multiply on the diagonalized co-ordinate.
# else: Q is indexed [batch_index, canonical_coordinate].
state[f"Q_{dim}"] *= scale
self._diagonalize_grad_cov(group, p, state)
def _diagonalize_grad_cov(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Called from _update_lrs(), this function diagonalizes the gradient covariance
state[f"grad_cov_{dim}"] by modifying the projections state[f"Q_{dim}"]: specifically,
by left-multiplying the projections by an orthogonal matrix.
"""
batch_size = p.shape[0]
numel = p.numel() // batch_size
for dim in range(1, p.ndim): # dim 0 is batch dim
size = p.shape[dim]
try:
# A and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
Q = state[f"Q_{dim}"]
grad_cov = state[f"grad_cov_{dim}"]
except KeyError:
assert size == 1 or size == numel
continue
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this at tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as::
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that, if the eigenvalues are more dispersed, it will be larger.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
# N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3) # ensure symmetric
U, S, V = _svd(N_grad_cov)
if random.random() < 0.01:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.transpose(2, 3), Q)
def _update_grad_cov(self,
group: dict,
p: Tensor,
state: dict):
"""
Updates the gradient covariance state[f"grad_cov_{dim}"] for appropriate
dimensions of the tensor.
"""
beta2 = group["betas"][1]
grad = p.grad
batch_size = p.shape[0]
for dim in range(1, grad.ndim): # dim 0 is batch dim.
size = grad.shape[dim]
name = f"grad_cov_{dim}"
if name not in state:
continue # e.g. size==1, size == numel, size > max_block_size
grad_cov = state[name] # shaped: (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = grad_cov.shape
g = grad.transpose(-1, dim)
# if grad were of shape (batch_size, x, size, y, z),
# and g of shape (batch_size, x, y, z, size),
# after the next line g will be of shape
# (batch_size, x, y, z, num_blocks, block_size)
g = g.reshape(batch_size, *g.shape[1:-1],
num_blocks, block_size)
g = _move_dim(g, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
g = g.reshape(batch_size, num_blocks, -1, block_size)
# this_grad_cov: (batch_size, num_blocks, block_size, block_size)
this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))
def _step(self,
group: dict,
p: Tensor,
state: dict):
"""
This function does the core update of self._step, in the case where the tensor
has more than 1 elements. It multiplies the moving-average gradient tensor by a
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
and takes a step in that direction.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
batch: if true, the first dimension of p is a batch dim, i.e. a batch of
parameter matrices.
This function modifies p.
"""
grad = p.grad
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
step = state["step"]
grad = self._project(grad, state, forward=True)
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt() + eps
grad = grad / denom
# and project back..
grad = self._project(grad, state, forward=False)
alpha = -lr * (1-beta1) * state["param_rms"]
delta = state["delta"]
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self,
group: dict,
p: Tensor,
state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms. p will not be a scalar, because it has a batch dim.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
alpha = -lr * (1-beta1) / denom
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr*(1-beta1))
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)
def _project(self,
M: Tensor,
state: dict,
forward: bool) -> Tensor:
"""
Multiply a tensor M by a projection Q on each of its dimensions (for which we have
such a projection)
Args:
M: The tensor to project, e.g. a gradient or a parameter change.
state: dict to look up the projections
forward: if True, go in the forward direction (from canonical to diagnonalized
co-ordinates); if False, the reverse. They differ by a transpose,
not an inverse, and do not make a round trip.
"""
numel = M.numel()
for dim in range(1, M.ndim): # dim 0 is batch dim (batch of parameter
# tensors of same shape)
try:
Q = state[f"Q_{dim}"] # shape: (batch_size, num_blocks, block_size, block_size)
except KeyError:
continue # no projection for this dim
(batch_size, num_blocks, block_size, block_size) = Q.shape
size = M.shape[dim] # == num_blocks * block_size
if forward:
# Q is indexed
# [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
# the forward direction we want to change canonical to
# diagonalized index, so we transpose.
Q = Q.transpose(2, 3)
# assume M currently has shape (batch_size, x, size, y, z), and dim==2.
M = M.transpose(-1, dim) # (batch_size, x, y, z, size)
M = M.reshape(batch_size, *M.shape[1:-1],
num_blocks, block_size) # (batch_size, x, y, z, num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
M = M.transpose(-1, dim) # (batch_size, x, size, y, z)
return M
def _smooth_param_rms(self,
group: dict,
rms: Tensor,
rank: int) -> Tensor:
"""
Smooths and normalizes to mean=1 a tensor of shape something like (batch_size, 1, size, 1, 1),
where for the nontrivial dim `size` it contains dagonalized parameter rms values; this is for
one dimension of a multiple-dimensional tensor.
This will be used to construct a learning-rate matrix.
Args:
group: dict for configuration values
rms: A tensor of shape (batch_size, [1,1,..], size, [1,1,1,..]), representing
(for each batch element) a list of root-mean-square values, one per
dimension of the space of size `size`.
rank: the assumed rank of the covariance matrix from which rms was derived,
used to decide how much smoothing to apply. This is actually the
minimal rank, if the parameter matrix were stay fixed during training,
but still relevant to know how robust the parameter covariance estimate is.
Returns:
a Tensor with the same shape as `rms` but with some smoothing applied so there
are no values too close to zero.
"""
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
param_pow = group["param_pow"]
eps = group["eps"]
batch_size = rms.shape[0]
size = rms.numel() // batch_size
rms = rms ** param_pow
# want expr to be of the form: smooth = alpha * size / (beta*rank + size)
# from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
mean = _mean(rms, exclude_dims=[0], keepdim=True)
rms += eps + smooth * mean
new_mean = (eps + (smooth + 1) * mean) # mean of modified rms.
ans = rms / new_mean
if True:
# Apply max_rms
max_rms = 4.0
ans.clamp_(max=max_rms*2)
ans /= _mean(ans, exclude_dims=[0], keepdim=True)
ans.clamp_(max=max_rms)
ans /= _mean(ans, exclude_dims=[0], keepdim=True)
return ans
def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
"""
Moves the position of one dimension in a Tensor, while keeping
the remaining dims in the same order. E.g. if x has
shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
[10, 3, 1, 2]. Returns the reshaped tensor which will be
a view of the original Tensor.
Args:
x: Tensor to reshape
orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
new_dim: New position of the original dimension, with
-x.ndim <= new_dim < x.ndim
Returns: a permuted view of x
"""
if orig_dim < 0:
orig_dim += x.ndim
if new_dim < 0:
new_dim += x.ndim
dims = list(range(x.ndim))
dims[orig_dim] = -1
dims.remove(-1)
dims.insert(new_dim, orig_dim)
return x.permute(dims)
def _diag(x: Tensor):
"""
like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
output of shape (B, M)
"""
if x.ndim == 3:
(B, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
elif x.ndim == 4:
(B, C, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
else:
return x.diag()
def _sum(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.sum where you specify dims to exclude, not dims to sum.
"""
if len(exclude_dims) == 0:
# dim=[] works the same as tuple(range(x.ndim)), which makes
# no sense.. supplying the full tuple for clarity and in case
# the interface changes in future.
return x.sum(dim=tuple(range(x.ndim)), keepdim=keepdim)
elif x.ndim == 1:
assert exclude_dim == [0] or exclude_dim == [-1]
return x # if one dim is excluded, there are no dims to sum, and
# x.sum(dim=[]) sums all dims so we should not call sum().
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
return x.sum(dim=dims, keepdim=keepdim)
def _mean(x: Tensor,
exclude_dims: List[int] = [],
keepdim: bool = False):
"""
Version of torch.mean where you specify dims to exclude, not dims to mean.
"""
if len(exclude_dims) == 0:
# dim=[] works the same as tuple(range(x.ndim)), which makes
# no sense.. supplying the full tuple for clarity and in case
# the interface changes in future.
return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
elif x.ndim == 1:
assert exclude_dim == [0] or exclude_dim == [-1]
return x # if one dim is excluded, there are no dims to mean, and
# x.mean(dim=[]) means all dims so we should not call mean().
exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
return x.mean(dim=dims, keepdim=keepdim)
def _svd(x: Tensor):
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
# some versions of torch SVD for CUDA)
randU = None
for i in range(4):
try:
U, S, V = x.svd()
s = U.sum()
if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
raise RuntimeError(f"svd failed (or random test), sum={s}")
if randU is not None:
U = torch.matmul(randU.t(), U)
return U, S, V # success
except:
logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
f"error in SVD. Will try again after random change.")
this_x = x
while this_x.ndim > 2:
this_x = this_x[0] # get rid of any batch dims
U, S, V = torch.randn_like(this_x).svd()
x = torch.matmul(U, x)
if randU is None:
randU = U
else:
randU = torch.matmul(U, randU)
class Cain(Optimizer):
"""Implements Cain algorithm, which is a substantially modified form of the
Adam optimizer.
The modifications can be summarized as follows:
(i) change Adam to what we called "Eve", which adds an extra configuration
value, "target_rms" (default value: 0.1); and for non-scalar parameters,
if their root-mean-square value exceeds "target_rms", we apply weight
decay (aggressive enough weight decay that the rms stays at target_rms).
This weight decay is applied in the same way as in AdamW, i.e. by
parameter shrinkage rather than by modifying the gradients.
(ii) By limiting the rms to 0.1, we reduce our modeling power; we restore
it by replacing all weights and biases-- in fact, all model parameters--
with expressions like
weight * weight_scale.exp() or bias * bias_scale.exp().
This has the benefit of making the learnable offsets and scales of
BatchNorm/LayerNorm unnecessary.
Below we describe the changes from Eve to Cain:
(iii) We removed all the weight_scale and bias_scale parameters from
the model and "moved them into the optimizer". Now we allow
parameters to have any rms value (that's <= 10), but we:
(a) make the update speed for a parameter proportional to
its rms value.
(b) Multiply the learning rate by 10 because we are getting rid of
target_rms=0.1
(c) simulate the effect of learning the weight_scale or bias_scale
in log space using standard Adam, with a scale to prevent it
from learning too fast. (see scale_exp_avg_sq in the
code below).
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Eve is unpublished so far.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
scale_speed (float, optional): we multiply the learning rate by this value
when learning (in log space) the scales of the parameter tensors
rms_eps (float, optional): epsilon value such that when parameter
rms value goes below this, we stop learning slower for that parameter,
i.e. we treat the rms as if it were rms_eps. In practice, rms values
will not go much below this.
param_max (float, optional): maximum absolute value for any parameter;
we will clamp parameters to satisfy -param_max <= param <= param_max
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-2,
betas=(0.9, 0.98),
eps=1e-8,
scale_speed=0.05,
rms_eps=1.0e-05,
param_max=20.0,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 < scale_speed < 1.0:
raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
if not 0.1 < param_max <= 10000.0:
raise ValueError("Invalid param_max value: {}".format(param_max))
if not 0.0 < rms_eps < 0.1:
raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
scale_speed=0.05,
rms_eps=rms_eps,
param_max=param_max,
)
super(Cain, self).__init__(params, defaults)
def __setstate__(self, state):
super(Cain, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
scale_speed = group["scale_speed"]
rms_eps = group["rms_eps"]
eps = group["eps"]
param_max = group["param_max"]
for p in group["params"]:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"Cain does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.numel() > 1:
# we keep track of the RMS value of the parameters to
# help set the learning rate... this emulates Eve.
state["param_rms"] = torch.ones_like(p)
# we learn the scale on the parameters in a way that
# also emulates Eve. scale_exp_avg_sq is the
# moving-average square of the gradient w.r.t. this
# scalar scale.
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
dtype=p.dtype)
delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
step = state["step"]
delta.mul_(beta1)
if p.numel() > 1:
if True:
# This block learns the scale on the parameters.
# scale_deriv is the derivative w.r.t. a log-space scale
# on the parameters, i.e. if we imagine: p =
# underlying_param * scale.exp(), scale_deriv is the
# derivative w.r.t. scale (it does not matter
# what value `scale` currently has).
scale_deriv = (p * grad).sum()
scale_exp_avg_sq = state["scale_exp_avg_sq"]
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
value=1 - beta2)
# should actually be step + 1, so on 1st minibatch we are not learning
# anything here. May fix this at some point.
scale_bias_correction2 = 1 - beta2 ** (step + 1)
scale_denom = (scale_exp_avg_sq.sqrt()).add_(eps)
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
# p = underlying_param * scale.exp(),
# delta is the change in `scale`.
scale_delta = scale_alpha * (scale_deriv / scale_denom)
delta.add_(p, alpha=scale_delta)
# Decay the second moment running average coefficient
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** step
grad_rms = (exp_avg_sq.sqrt()).add_(eps)
this_delta = grad / grad_rms
param_rms = state["param_rms"]
self._update_param_rms(p, param_rms, step, rms_eps)
alpha = (-(1-beta1) * lr * bias_correction2)
# geometrically interpolate the per-element param_rms with its mean
#param_rms_half = (param_rms * param_rms.mean()) ** 0.5
delta.add_(this_delta * param_rms, alpha=alpha)
# below, will do: delta.add_(this_delta, alpha=alpha)
else:
# p.numel() == 1
# Decay the second moment running average coefficient
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** step
denom = (exp_avg_sq.sqrt()).add_(eps)
this_delta = grad / denom
alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
delta.add_(this_delta, alpha=alpha)
p.add_(delta)
if step % 10 == 0:
p.clamp_(min=-param_max, max=param_max)
state["step"] += 1
return loss
def _update_param_rms(self, p: Tensor, param_rms: Tensor,
step: int, rms_eps: float):
"""
Updates `param_rms`. Require p.numel() > 1
Args:
p: parameter whose magnitude/rms we are measuring
param_rms: a Tensor containing non-negative values, with the same
shape as p. Contains an estimate of the root-mean-square
value of the parameter. This has a special structure where
it must be a product of one-dimensional quantities,
one for each axis of p.
We exclude certain axes to avoid estimating the rms value
from a too-small number of parameters.
step: the step of the algorithm
rms_eps: epsilon such that we'll aim to prevent param_rms from
going much lower than this.
"""
if step % 1000 == 0:
# Periodically refresh param_rms to keep the invariance that it is a
# product over its axes (numerical roundoff might destroy this)
param_rms.fill_(1.0)
for dim in range(p.ndim):
self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)
else:
dim = step % p.ndim
self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)
def _update_param_rms_for_dim(self, p: Tensor,
param_rms: Tensor,
dim: int,
rms_eps: float):
"""
Updates `param_rms` on one of its axes/dimensions.
Args:
p: parameter whose magnitude/rms we are measuring
param_rms: a Tensor containing non-negative values, with the same
shape as p. Contains an estimate of the root-mean-square
value of the parameter. This has a special structure where
it must be a product of one-dimensional quantities,
one for each axis of p.
We exclude certain axes to avoid estimating the rms value
from a too-small number of parameters.
dim: with 0 <= dim < p.ndim, the
dimension/axis on which to update the scale factor
rms_eps: epsilon such that we'll aim to prevent param_rms from
going much lower than this.
"""
# var_factor is the square of the factor to change param_rms by. p will
# not be super large, we limit it to -param_max <= p <= param_max.
# Clamping p**2 to min=1.0e-10 should prevent param_rms from getting much
# smaller than 1.0e-05, which should prevent division by zero here.
var_factor = (p ** 2).clamp(min=rms_eps*rms_eps) / (param_rms ** 2)
if p.numel() <= 2 * p.shape[dim]:
var_factor = var_factor.mean()
else:
dims = tuple(d for d in range(p.ndim) if d != dim)
var_factor = var_factor.mean(dim=dims, keepdim=True)
#if random.random() < 0.01:
# logging.info(f"shape={p.shape}, var_factor={var_factor}")
param_rms.mul_(var_factor.sqrt())
class LRScheduler(object):
"""
Base-class for learning rate schedulers where the learning-rate depends on both the
batch and the epoch.
"""
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
self.optimizer = optimizer
self.verbose = verbose
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
self.base_lrs = [
group["initial_lr"] for group in optimizer.param_groups
]
self.epoch = 0
self.batch = 0
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {
"base_lrs": self.base_lrs,
"epoch": self.epoch,
"batch": self.batch,
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_last_lr(self) -> List[float]:
"""Return last computed learning rate by current scheduler. Will be a list of float."""
return self._last_lr
def get_lr(self):
# Compute list of learning rates from self.epoch and self.batch and
# self.base_lrs; this must be overloaded by the user.
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
raise NotImplementedError
def step_batch(self, batch: Optional[int] = None) -> None:
# Step the batch index, or just set it. If `batch` is specified, it
# must be the batch index from the start of training, i.e. summed over
# all epochs.
# You can call this in any order; if you don't provide 'batch', it should
# of course be called once per batch.
if batch is not None:
self.batch = batch
else:
self.batch = self.batch + 1
self._set_lrs()
def step_epoch(self, epoch: Optional[int] = None):
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
# you should call this at the start of the epoch; if you don't provide the 'epoch'
# arg, you should call it at the end of the epoch.
if epoch is not None:
self.epoch = epoch
else:
self.epoch = self.epoch + 1
self._set_lrs()
def _set_lrs(self):
values = self.get_lr()
assert len(values) == len(self.optimizer.param_groups)
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group["lr"] = lr
self.print_lr(self.verbose, i, lr)
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logging.info(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
class Eve(Optimizer):
"""
Implements Eve algorithm. This is a modified version of AdamW with a special
way of setting the weight-decay / shrinkage-factor, which is designed to make the
rms of the parameters approach a particular target_rms (default: 0.1). This is
for use with networks with 'scaled' versions of modules (see scaling.py), which
will be close to invariant to the absolute scale on the parameter matrix.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Eve is unpublished so far.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
this value means that the weight would decay significantly after
about 3k minibatches. Is not multiplied by learning rate, but
is conditional on RMS-value of parameter being > target_rms.
target_rms (float, optional): target root-mean-square value of
parameters, if they fall below this we will stop applying weight decay.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-8,
weight_decay=1e-3,
target_rms=0.1,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
target_rms=target_rms,
)
super(Eve, self).__init__(params, defaults)
def __setstate__(self, state):
super(Eve, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"AdamW does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
group["eps"]
)
step_size = group["lr"] / bias_correction1
target_rms = group["target_rms"]
weight_decay = group["weight_decay"]
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
if random.random() < 0.0005:
step = (exp_avg/denom) * step_size
logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
return loss
class Eden(LRScheduler):
"""
Eden scheduler.
lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
E.g. suggest initial-lr = 0.003 (passed to optimizer).
Args:
optimizer: the optimizer to change the learning rates on
lr_batches: the number of batches after which we start significantly
decreasing the learning rate, suggest 5000.
lr_epochs: the number of epochs after which we start significantly
decreasing the learning rate, suggest 6 if you plan to do e.g.
20 to 40 epochs, but may need smaller number if dataset is huge
and you will do few epochs.
"""
def __init__(
self,
optimizer: Optimizer,
lr_batches: Union[int, float],
lr_epochs: Union[int, float],
verbose: bool = False,
):
super(Eden, self).__init__(optimizer, verbose)
self.lr_batches = lr_batches
self.lr_epochs = lr_epochs
def get_lr(self):
factor = (
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
) ** -0.25 * (
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
** -0.25
)
return [x * factor for x in self.base_lrs]
def _test_eden():
m = torch.nn.Linear(100, 100)
optim = Cain(m.parameters(), lr=0.003)
scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
for epoch in range(10):
scheduler.step_epoch(epoch) # sets epoch to `epoch`
for step in range(20):
x = torch.randn(200, 100).detach()
x.requires_grad = True
y = m(x)
dy = torch.randn(200, 100).detach()
f = (y * dy).sum()
f.backward()
optim.step()
scheduler.step_batch()
optim.zero_grad()
logging.info(f"last lr = {scheduler.get_last_lr()}")
logging.info(f"state dict = {scheduler.state_dict()}")
def _test_eve_cain():
import timeit
from scaling import ScaledLinear
E = 100
B = 4
T = 2
logging.info("in test_eve_cain")
#device = torch.device('cuda')
device = torch.device('cpu')
dtype = torch.float32
fix_random_seed(42)
# these input_magnitudes and output_magnitudes are to test that
# Abel is working as we expect and is able to adjust scales of
# different dims differently.
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
for iter in [3, 2, 1, 0]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
# TODO: find out why this is not converging...
m = torch.nn.Sequential(Linear(E, 200),
torch.nn.PReLU(),
Linear(200, 200),
torch.nn.PReLU(),
Linear(200, E),
).to(device)
train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()
avg_loss = 0.0
for epoch in range(150):
scheduler.step_epoch()
#if epoch == 100 and iter in [2,3]:
# optim.reset_speedup() # check it doesn't crash.
#if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y)**2).mean() * 100.0
if epoch == 0 and n == 0:
avg_loss = loss.item()
else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
if n == 0 and epoch % 5 == 0:
#norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
#norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
#norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
#norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
#scale1 = '%.2e' % (m[0].weight_scale.exp().item())
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
lr = scheduler.get_last_lr()[0]
logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
scheduler.step_batch()
#diagnostic.print_diagnostics()
stop = timeit.default_timer()
logging.info(f"Iter={iter}, Time taken: {stop - start}")
logging.info(f"last lr = {scheduler.get_last_lr()}")
#logging.info("state dict = ", scheduler.state_dict())
#logging.info("optim state_dict = ", optim.state_dict())
logging.info(f"input_magnitudes = {input_magnitudes}")
logging.info(f"output_magnitudes = {output_magnitudes}")
def _test_svd():
device = 'cuda'
for i in range(1000):
X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device)
U, S, V = _svd(X)
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
logging.getLogger().setLevel(logging.INFO)
#_test_svd()
_test_eve_cain()
#_test_eden()