# Copyright  2022  Xiaomi Corp.  (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.optim import Optimizer

from lhotse.utils import fix_random_seed
from scaling import ActivationBalancer


class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.

    Args:
      params: the parameters or parameter groups to optimize (as for class Optimizer)
      defaults: a dict of default hyperparameter values (as for class Optimizer)
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    def batched_params(self, param_group):
        """
        This function returns an iterator over tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this fake parameter
        (it will be physically located in the "state" for one of the real
        parameters, the first one that has any particular shape and dtype).

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
              state = self.state[p]
              ...
        </code>
        you can do:
        <code>
          for p, state in self.batched_params(group["params"]):
              ...
        </code>

        Args:
          param_group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
        """
        batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str, *shape) to list of nn.Parameter

        for p in param_group:
            assert p.grad is not None, "All parameters must have a grad when you call optimizer.step()"
            assert not p.grad.is_sparse, "Sparse gradients are not supported."
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)

        for p in param_group:
            key = (str(p.dtype), *p.shape)
            batch = batches[key]
            if p is batch[0]:  # if this is the 1st param in the batch...
                # We arbitrarily store the state in the
                # state corresponding to the 1st parameter in the
                # group.  class Optimizer will take care of saving/loading state.
                state = self.state[p]
                p_stacked = torch.stack(batch)
                grad = torch.stack([p.grad for p in batch])
                p_stacked.grad = grad
                yield p_stacked, state  # <-- calling code will do the actual optimization here!
                # Now un-stack the parameter changes.
                for i, p in enumerate(batch):
                    p.copy_(p_stacked[i])
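
# Illustrative usage sketch of batched_params() (not part of the optimizer; the
# tensors below are hypothetical).  Given two weight matrices of shape (3, 4)
# and one bias of shape (4,), the generator yields one stacked "fake" parameter
# of shape (2, 3, 4) and one of shape (1, 4):
#
#   w1 = torch.nn.Parameter(torch.randn(3, 4))
#   w2 = torch.nn.Parameter(torch.randn(3, 4))
#   b = torch.nn.Parameter(torch.randn(4))
#   for q in (w1, w2, b):
#       q.grad = torch.zeros_like(q)
#   opt = BatchedOptimizer([w1, w2, b], defaults={})
#   for p, state in opt.batched_params(opt.param_groups[0]["params"]):
#       print(p.shape)  # torch.Size([2, 3, 4]), then torch.Size([1, 4])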


class PrAdam(BatchedOptimizer):
    """
    Implements 'Projected Adam', a variant of Adam with projections for each
    dim of each tensor.

    Args:
      params: The parameters or param_groups to optimize (like other Optimizer subclasses)
      lr: The learning rate.  We will typically use a learning rate schedule that starts
          at 0.03 and decreases over time, i.e. much higher than other common
          optimizers.
      betas: beta1, beta2 are momentum constants for regular momentum, and for the moving
          sum-sq grad (the latter only for scalar tensors).  Must satisfy
          0 < beta1 <= beta2 < 1.
      size_lr_scale: A scaling factor on the learning rate, that we use to update the
          scale of each parameter tensor.  If each parameter were decomposed
          as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
          is the scaling factor on the learning rate of p_scale.
      param_rms_smooth0: Limiting value of the smoothing proportion for the parameter matrix,
          as the assumed rank of the param covariance [== product of sizes on the other
          tensor dims] approaches 0.
      param_rms_smooth1: Smoothing proportion for the parameter matrix, if the assumed rank
          of the param covariance equals the dimension of the covariance matrix.
          param_rms_smooth{0,1} determine the smoothing proportions for other
          conditions.
      cov_min,cov_max: [IMPORTANT] 4-tuples of minimums and maximums of the diagonal values of
          covariance matrices, after normalizing to unit-mean.  The first 3 are
          for smoothing the parameter covariance, normalized in 3 different ways:
            (1) relative to its own diagonal (in a basis that diagonalizes the grad
                covariance);
            (2) multiplied by the grad covariance;
            (3) in the canonical basis.
          (4) is for smoothing the grad covariance used for (2).
      cov_pow: This was mainly added for development and experimentation purposes;
          it allows you to smooth the parameter covariance matrices at the
          stages (1), (2), (3) of smoothing mentioned above, and also
          the gradient covariance matrix used in stage (2) of smoothing.  If the
          first 3 values are not 1.0 it will cause a measurable speed penalty because it
          requires SVD.  We recommend leaving all of these at 1.0.
      eps: A general-purpose epsilon to prevent division by zero
      denom_rel_eps: An epsilon value used to keep the elements of the denominator
          (exp_avg_sq.sqrt()) not too small relative to the mean value; this will limit
          the speed of the update along directions with very small gradients.
      param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
          learning the scale on the parameters (we'll constrain the rms of each non-scalar
          parameter tensor to be >= this value)
      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
          learning the scale on the parameters (we'll constrain the rms of each non-scalar
          parameter tensor to be <= this value)
      scalar_max: Maximum absolute value for scalar parameters (applicable if your
          model has any parameters with numel() == 1)
      size_update_period: The periodicity, in steps, with which we update the size (scale)
          of the parameter tensor.  This is provided to save a little time
          in the update.
      lr_update_period: [IMPORTANT] A 2-tuple of ints that determines the periodicity, in
          steps, with which we update the learning-rate matrices.  The first number is the
          periodicity at the start of training, the second number is the periodicity
          later in training, and we gradually increase from one to the other.
          The reason for such a complicated schedule is that updating the learning-rate
          matrices is very slow, principally because it requires SVD, and SVD
          seems to have quite slow implementations.
      max_block_size: [IMPORTANT] The maximum block size in block-diagonal co-ordinate
          transformations.  You can probably set this to 512 or 1024.  Larger
          values will require more MEMORY and may be a bit slower, but should
          lead to better optimization performance.
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        betas=(0.9, 0.98),
        size_lr_scale=0.1,
        cov_min=(0.025, 0.0025, 0.02, 0.0001),
        cov_max=(10.0, 80.0, 5.0, 400.0),
        cov_pow=(1.0, 1.0, 1.0, 1.0),
        param_rms_smooth0=0.4,
        param_rms_smooth1=0.2,
        eps=1.0e-08,
        denom_rel_eps=1.0e-05,
        param_min_rms=1.0e-05,
        param_max_rms=2.0,
        scalar_max=2.0,
        size_update_period=4,
        lr_update_period=(200, 1000),
        # grad_cov_period does not affect speed very much.  If __name__ ==
        # "__main__", keep this coprime to 20, which is the number of
        # sample pairs used in the test code.
        grad_cov_period=(3 if __name__ == "__main__" else 5),
        max_block_size=1024,
    ):

        defaults = dict(
            lr=lr,
            size_lr_scale=size_lr_scale,
            cov_min=cov_min,
            cov_max=cov_max,
            cov_pow=cov_pow,
            param_rms_smooth0=param_rms_smooth0,
            param_rms_smooth1=param_rms_smooth1,
            betas=betas,
            eps=eps,
            denom_rel_eps=denom_rel_eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            lr_update_period=lr_update_period,
            grad_cov_period=grad_cov_period,
            max_block_size=max_block_size,
        )

        super(PrAdam, self).__init__(params, defaults)
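
    # Illustrative usage sketch (not part of this file's test code; the model
    # below is hypothetical):
    #
    #   model = torch.nn.Linear(256, 256)
    #   optim = PrAdam(model.parameters(), lr=3e-02)
    #   loss = model(torch.randn(32, 256)).pow(2).mean()
    #   loss.backward()
    #   optim.step()
    #   optim.zero_grad()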

    def __setstate__(self, state):
        super(PrAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p, state in self.batched_params(group["params"]):

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "PrAdam optimizer does not support sparse gradients"
                    )
                # State initialization
                if len(state) == 0:
                    self._init_state(group, p, state)
                self._step_one_batch(group, p, state)

        return loss

    def _init_state(self,
                    group: dict,
                    p: Tensor,
                    state: dict):
        """
        Initializes the state dict for parameter 'p'.  Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
        parameters of a given shape.

        Args:
           group: Dict to look up configuration values.
           p: The parameter that we are initializing the state for; actually
              it will be stacked-together "real" parameters.
           state: Dict from string to whatever state we are initializing.
        """
        eps = group["eps"]
        size_update_period = group["size_update_period"]
        max_block_size = group["max_block_size"]

        state["step"] = 0

        kwargs = {'device': p.device, 'dtype': p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
        # compute "exp_avg" like in Adam, we store and decay a
        # parameter-change "delta", which combines all forms of
        # update.  This is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        batch_size = p.shape[0]
        numel = p.numel() // batch_size

        if numel > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # It has a shape like (batch_size, 1, 1, 1, 1).
            param_rms = _mean(p**2, exclude_dims=[0], keepdim=True).sqrt() + eps
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
                                               **kwargs)

        # exp_avg_sq is in the projected space.  It is periodically zeroed, and we
        # set zero_step whenever this happens.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        # ignore_rank1_dims = True means we don't do our basis-changing update
        # on tensors like biases that only have one non-trivial dimension.
        # "rank1" refers to the rank of our estimate of the parameter
        # covariance, if we were to estimate it just from the current parameter
        # value.
        ignore_rank1_dims = True

        trivial_update = True
        for dim in range(1, p.ndim):
            size = p.shape[dim]
            if size == 1 or (size == numel and ignore_rank1_dims):
                continue
            trivial_update = False

        state["trivial_update"] = trivial_update  # so we know whether to zero exp_avg_sq stats.
        if trivial_update:
            return

        # "zero_step" being a member of state is the sign that this parameter has
        # at least one dim that has a projection.  It also records the most recent
        # step on which we zeroed state["exp_avg_sq"].
        state["zero_step"] = 0
        # last_param_scale_update records the last time we updated the part of the learning-rate
        # matrices that relates to the parameter covariance; we avoid doing this too often
        # as it doesn't change very rapidly and the SVD is quite slow.
        state["last_param_scale_update"] = -1

        for dim in range(1, p.ndim):
            size = p.shape[dim]
            if size == 1 or (size == numel and ignore_rank1_dims):
                continue

            # Most of the time, (num_blocks, block_size) will equal (1, size)
            num_blocks, block_size = self._get_block_size(size, max_block_size)

            # Q_{dim} is the learning-rate matrix for this dimension, a matrix
            # indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
            # It is used twice in the update step, once transposed and once without transpose.
            Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                batch_size, num_blocks, block_size, block_size).contiguous()
            state[f"Q_{dim}"] = Q
            # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension,
            # treating all other dims as a batch axis.  Also initialize as identity.
            state[f"param_cov_{dim}"] = Q.clone()

            # grad_cov_{dim} is the covariance of gradients on this axis (without
            # any co-ordinate changes), treating all other axes as a batch axis.
            # This is needed when we re-diagonalize, and to compute the
            # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2,
            # only for purposes of computing the scalar factor on the
            # learning rate; and, as a result of this, it also contributes something
            # to the gradient w.r.t. f"Q_{dim}", which is one reason
            # why we allocate a variable to keep track of its moving average
            # instead of just using a temporary and smoothing the scalar factor.
            state[f"grad_cov_{dim}"] = torch.zeros_like(Q)

    def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int, int]:
        """
        Returns information about the block size for a block-diagonal structure
        of covariances and co-ordinate transformations.

          size: The size of the parameter tensor on one of its dimensions, e.g.
              a channel dim or kernel size
          max_block_size: The maximum block size allowed (of the blocks in a
              block-diagonal structure)
        Returns:
           (num_blocks, block_size)
        where num_blocks * block_size == size and block_size <= max_block_size
        """
        for num_blocks in range(1, size):
            block_size = size // num_blocks
            if block_size * num_blocks == size and block_size <= max_block_size:
                return (num_blocks, block_size)
        assert False, (size, max_block_size)  # code error or e.g. negative or
                                              # zero inputs
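
    # Illustrative examples (hypothetical sizes, not from any recipe), with
    # max_block_size=1024:
    #   _get_block_size(512, 1024)  -> (1, 512)   # fits in one block
    #   _get_block_size(1536, 1024) -> (2, 768)   # smallest num_blocks whose block_size
    #                                             # divides evenly and is <= max_block_size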

    def _step_one_batch(self,
                        group: dict,
                        p: Tensor,
                        state: dict):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
            group: dict to look up configuration values
            p: parameter to update (actually multiple parameters stacked together
               as a batch)
            state: state-dict for p, to look up the optimizer state
        """
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        grad_cov_period = group["grad_cov_period"]
        eps = group["eps"]
        beta1 = group["betas"][0]

        grad = p.grad
        step = state["step"]
        delta = state["delta"]

        delta.mul_(beta1)
        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = _sum(p * grad,
                                                          exclude_dims=[0],
                                                          keepdim=True)
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(_mean(p ** 2, exclude_dims=[0],
                                      keepdim=True).sqrt().clamp_(min=eps))
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
            self._step_scalar(group, p, state)
        else:
            if self._is_lr_update_step(group, state):
                self._update_param_cov(group, p, state)
                self._compute_bases(group, p.shape, state)
                self._zero_exp_avg_sq(state)
            if step % grad_cov_period == 0:
                self._update_grad_cov(group, p, state)
            self._step(group, p, state)

        state["step"] = step + 1

    def _size_update(self,
                     group: dict,
                     scale_grads: Tensor,
                     p: Tensor,
                     state: dict) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying_param and on scale, this function does the update
        on `scale`.

        Args:
           group: dict to look up configuration values
           scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1, ...) containing
               grads w.r.t. the scales.
           p: The parameter to update
           state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["size_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        step = state["step"]

        batch_size = p.shape[0]
        size_update_period = scale_grads.shape[0]
        # Correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period

        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(dim=0),  # mean over dim `size_update_period`
            alpha=1-beta2_corr)  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr ** size_step
        # We don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        eps = 1.0e-10  # this value is not critical.
        denom = scale_exp_avg_sq.sqrt() + eps  # shape: (batch_size, 1, 1, ..)

        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom

        is_too_small = (param_rms < param_min_rms)
        is_too_large = (param_rms > param_max_rms)

        scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1-beta1))
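
    # Note on the size-update math (informal, following the docstring above):
    # if we write p = underlying_param * scale.exp(), then d(p)/d(scale) = p,
    # so by the chain rule d(loss)/d(scale) = (grad * p).sum() -- which is
    # exactly what is accumulated into `scale_grads` in _step_one_batch().
    # The update above is then just Adam on `scale`, with the resulting change
    # in `scale` converted back into a change in p via delta += p * scale_step.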

    def _update_param_cov(self,
                          group: dict,
                          p: Tensor,
                          state: dict) -> None:
        """
        Updates state[f"param_cov_{dim}"] for each dim of tensor p
        (except batch and trivial and rank-1 dims).
        """
        eps = group["eps"]

        # zero_step is always the last time we called _update_param_cov.
        # Our aim is to compute the parameter covariance averaged over all time
        # (in steps) from the start, so "this_weight" that we give to this
        # new covariance is proportional to the interval of time it represents,
        # i.e. the interval from zero_step until step, while the existing "param_cov"
        # represents the interval from the start of training until zero_step.
        # The min of 0.5 is a special case to ensure that the "torch.eye" we initialized
        # param_cov with gets some weight on the 1st time we call this.
        zero_step = state["zero_step"]
        step = state["step"]

        this_weight = min(0.5, (step - zero_step) / step)

        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        for dim in range(1, p.ndim):
            size = p.shape[dim]
            try:
                param_cov = state[f"param_cov_{dim}"]
            except KeyError:
                continue  # e.g. size == 1 or size == numel

            # param_cov shape: (batch_size, num_blocks, block_size, block_size)
            (batch_size, num_blocks, block_size, block_size) = param_cov.shape

            # M: (batch_size, num_blocks, x, block_size)
            M = p.transpose(dim, -1).reshape(batch_size, -1,
                                             num_blocks, block_size).transpose(1, 2)

            # This will be a batched matrix multiply.  Shape of this_param_cov:
            # (batch_size, num_blocks, block_size, block_size)
            this_param_cov = torch.matmul(M.transpose(2, 3), M)

            # Normalize the scale of this_param_cov, in case the parameter scale
            # changes significantly during training, which would cause some
            # parts of the training timeline to be more highly weighted.
            # The expression on the r.h.s. of "/=" has shape (batch_size, 1, 1, 1).
            this_param_cov /= _mean(_diag(this_param_cov),
                                    exclude_dims=[0], keepdim=True).unsqueeze(-1) + eps

            param_cov.mul_(1-this_weight).add_(this_param_cov,
                                               alpha=this_weight)

    def _zero_exp_avg_sq(self,
                         state: dict) -> None:
        """
        Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"].
        """
        state["exp_avg_sq"].zero_()
        state["zero_step"] = state["step"]

    def _is_lr_update_step(self,
                           group: dict,
                           state: dict) -> bool:
        """
        Returns True if on this step we need to update the learning-rate matrices for this tensor
        and False if not (note: this may just mean we update the gradient-dependent part).
        The periodicity with which we update them increases from
        (by default) 200 at the start of training to 1000 later on.
        """
        try:
            zero_step = state["zero_step"]
        except KeyError:
            # This parameter tensor has no learning-rate matrices to estimate.
            return False
        step = state["step"]
        period_initial, period_final = group["lr_update_period"]
        # This formula gradually increases the periodicity from period_initial at the start
        # to period_final when step >> 4 * period_final.
        cur_update_period = (period_initial +
                             ((period_final - period_initial) * step /
                              (step + 4 * period_final)))
        return (step >= zero_step + cur_update_period)
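
    # Worked example of the schedule above, with the default
    # lr_update_period=(200, 1000):
    #   step = 0     -> cur_update_period = 200
    #   step = 4000  -> 200 + 800 * 4000 / (4000 + 4000) = 600
    #   step -> inf  -> approaches 1000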

    def _compute_bases(self,
                       group: dict,
                       p_shape: torch.Size,
                       state: dict):
        """
        For each tensor dimension that we are preconditioning, this function sets
        state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
        block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate].
        This will be the matrix that diagonalizes Z,
        where Z is the symmetric matrix satisfying:
               Z G Z = P     (eqn:1)
        with G being the gradient covariance grad_cov, P being the smoothed version of the parameter
        covariance param_cov, and Z the symmetric positive definite learning-rate matrix.
        We'll discuss later how we solve for Z.  Firstly, because we want
        a basis that is as meaningful as possible for smoothing P, we will first put P and G in a basis
        where G is diagonal.  That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the
        left and right by U^T and U respectively, and inserting some expressions equivalent to I,
        we have:
             (U^T Z U) (U^T G U) (U^T Z U) = (U^T P U)
         or:            Z' G' Z' = P'    (eqn:2)
        where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal.  We are actually going to
        be solving (eqn:2), and then computing:
              Z = U Z' U^T.

        A solution to (eqn:2) is as follows.  We are going to be using a Cholesky-based solution in
        favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first
        have to be careful that the input is not close to singular, though).

        So, with the Cholesky decomposition of P' being:
             P' = C C^T,
        the solution is going to be:
             Z' = C (C^T G' C)^{-0.5} C^T   (eqn:3)
        [[ we can verify that this is a solution by multiplying (eqn:2) by C^{-1} and C^{-T} on the
           left and right respectively, giving:
              (C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
           which we can immediately see is the case. ]]

        It's actually going to be more convenient to compute Z'^{-1} (the inverse of Z'), because
        we ultimately only need the basis that diagonalizes Z, and (C^T G' C) may be singular.
        It's not hard to work out that:
             Z'^{-1} = C^{-T} (C^T G' C)^{0.5} C^{-1}   (eqn:4)

        Args:
          group: dict to look up config values
          p_shape: the shape of the batch of identical-sized tensors we are optimizing
          state: dict to look up optimization state

        This function does not return anything useful; it sets state[f"Q_{dim}"] in place to
        the orthogonal matrix that diagonalizes Z (this defines the space in which we do the
        actual update).
        """
        ndim = len(p_shape)
        denom_rel_eps = group["denom_rel_eps"]
        eps = group["eps"]

        batch_size = p_shape[0]
        numel = p_shape.numel() // batch_size

        for dim in range(1, ndim):
            size = p_shape[dim]
            try:
                Q = state[f"Q_{dim}"]
            except KeyError:
                assert size == 1 or size == numel, size
                continue  # e.g. size == 1 or size == numel

            param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
            grad_cov = state[f"grad_cov_{dim}"]    # (batch_size, num_blocks, block_size, block_size)
            (batch_size, num_blocks, block_size, block_size) = param_cov.shape

            # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape
            # is the same as param_cov and grad_cov.
            #
            # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).

            G = grad_cov.clone()

            # Use the form of the diagonalized gradient matrix that we get after
            # we add the Adam-type smoothing with epsilon.
            G_diag = _diag(G)  # aliased with G
            G_diag += (_mean(G_diag, exclude_dims=[0], keepdim=True) * (denom_rel_eps * denom_rel_eps) +
                       (eps * eps))

            P = param_cov.clone()
            P = self._smooth_param_cov(group, p_shape, P, G)

            if random.random() < 0.25:
                _, S, _ = _svd(P)
                logging.info(f"Eigs of P are {S[0][0]}")

            # C will satisfy: P == torch.matmul(C, C.transpose(2, 3)).
            # C is of shape (batch_size, num_blocks, block_size, block_size).
            # def _fake_cholesky(X):
            #     U_, S_, _ = _svd(X)
            #     return U_ * S_.sqrt().unsqueeze(-2)
            # C = _fake_cholesky(P)
            C = P.cholesky()

            # A matrix that takes normally distributed data to P would
            # be C, because C I C^T = C C^T = P.  We can actually use *any* matrix
            # that takes normally distributed data to P, so we can use
            # C U for any orthogonal U, since C U I U^T C^T == P.
            # So there is no harm in choosing a matrix U that diagonalizes the
            # projected grad_cov; grad_cov gets projected by C (i.e. we form
            # C^T grad_cov C below).
            grad_cov_proj = torch.matmul(C.transpose(2, 3),
                                         torch.matmul(grad_cov, C))

            # OK, grad_cov is diagonalized by U^T C^T.  So the projection that we
            # apply to the param cov is C U.
            U, S, _ = _svd(grad_cov_proj)

            # proj is indexed [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate],
            # so we need to transpose to get Q_{dim}.
            proj = torch.matmul(C, U)

            Q[:] = proj.transpose(2, 3)
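
    # A minimal numerical sanity-check of the Cholesky-based solution described
    # in the docstring above (illustrative only; standalone, not used by the
    # optimizer).  It verifies that Z' = C (C^T G' C)^{-0.5} C^T satisfies
    # Z' G' Z' = P' for a random diagonal G' and a random SPD P':
    #
    #   import torch
    #   torch.manual_seed(0)
    #   n = 4
    #   A = torch.randn(n, n)
    #   P = A @ A.t() + n * torch.eye(n)        # random SPD "P'"
    #   G = torch.diag(torch.rand(n) + 0.1)     # random diagonal "G'"
    #   C = torch.linalg.cholesky(P)
    #   M = C.t() @ G @ C
    #   U, S, _ = torch.linalg.svd(M)           # M is SPD, so this is its eigendecomposition
    #   M_inv_sqrt = U @ torch.diag(S.rsqrt()) @ U.t()
    #   Z = C @ M_inv_sqrt @ C.t()
    #   assert torch.allclose(Z @ G @ Z, P, atol=1e-3)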

    def _smooth_param_cov(self,
                          group: dict,
                          p_shape: torch.Size,
                          P: Tensor,
                          G: Tensor) -> Tensor:
        """
        This function returns a modified/smoothed version of the parameter covariance
        P.

        Args:
          group: dict to look up config values
          p_shape: The shape of the parameter we are optimizing
          P: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
             containing the parameter covariance; this comes from
             state[f"param_cov_{dim}"], which is an estimate of the covariance of the
             parameter p, averaged over time, and taken over dimension `dim` of the tensor.
          G: the gradient covariance, of shape (batch_size, num_blocks,
             block_size, block_size)

        The smoothing done here limits the extent to which the parameter covariance
        can be strongly "off-diagonal" with respect to the gradient covariance.  That is:
        if the parameter covariance is just the gradient covariance to some power, this
        function does no smoothing; but if it is highly off-diagonal we do more smoothing.
        """
        # Do smoothing based on 'rank',
        # which is intended to compensate for bad estimates of P.
        (batch_size, num_blocks, block_size, block_size) = P.shape
        # `rank` is the rank of each block of P if we were to estimate it from just one
        # parameter tensor.  We average P over time, but actually it won't be changing
        # too much, so `rank` does tell us something.
        size = num_blocks * block_size
        rank = p_shape.numel() // (size * batch_size)  # actually the rank of each block
        smooth0 = group["param_rms_smooth0"]
        smooth1 = group["param_rms_smooth1"]
        # We want the expression for the smoothing amount to be of the form:
        #    smooth = alpha * size / (beta*rank + size)
        # For "size" here, we actually want to use block_size, since we are concerned about the
        # robustness of the covariance within these blocks.
        # param_rms_smooth{0,1} represent the user-specified desired amount of smoothing
        # when rank==0*size and rank==1*size, respectively.
        # From rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
        # From setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
        # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1.
        smooth = smooth0 * block_size / ((smooth0/smooth1 - 1) * rank + block_size)
        if True:
            logging.info(f"block size={block_size}, rank={rank}, smooth={smooth}")

        # Add a rank-dependent smoothing amount to the diagonal of P.  _diag() returns an aliased tensor.
        # We don't need to multiply `smooth` by anything, because at this point, P should have
        # diagonal elements close to 1.
        P = P.clone()
        P_diag = _diag(P)
        P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
        P_diag += smooth * P_diag_mean

        # G = G.clone()
        # G_diag = _diag(G)  # aliased
        # G_diag *= 1.005  # ensure invertible.
        G = self._smooth_cov(G,
                             group["cov_min"][3],
                             group["cov_max"][3],
                             group["cov_pow"][3])

        P = self._apply_min_max_with_metric(P, G,
                                            group["cov_min"][1],
                                            group["cov_max"][1])

        # Apply a 3rd round of smoothing in the canonical basis.
        P = self._smooth_cov(P,
                             group["cov_min"][2],
                             group["cov_max"][2],
                             group["cov_pow"][2])
        return P
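
    # Worked example of the smoothing proportion above, with the defaults
    # param_rms_smooth0=0.4 and param_rms_smooth1=0.2 (so beta = 0.4/0.2 - 1 = 1)
    # and a hypothetical block_size of 512:
    #   rank = 0    -> smooth = 0.4 * 512 / (0 + 512)    = 0.4
    #   rank = 512  -> smooth = 0.4 * 512 / (512 + 512)  = 0.2
    #   rank = 1536 -> smooth = 0.4 * 512 / (1536 + 512) = 0.1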

    def _smooth_cov(self,
                    X: Tensor,
                    min_eig: float,
                    max_eig: float,
                    power: float = 1.0) -> Tensor:
        """
        Returns a `smoothed` version of a symmetric positive definite covariance matrix
        [with block-diagonal structure, in a batch].  This is done without SVD (which
        can be very slow).
        The eigenvalues L will be transformed as:

            L = L + min_eig * L.mean() + eps
            L /= L.mean()
            L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
            L = L ** power              # needs SVD; will get rid of this requirement later.
            L /= L.mean()

        # Note on approximating functions like x^0.75 for smallish x: on wolframalpha, type:
        #   plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
        # [this starts to diverge after 5 or so]

        Args:
            min_eig: minimum allowed eigenvalue of returned X
            max_eig: maximum allowed eigenvalue of returned X
            power: power to take eigenvalues to
            X: the batch of symmetric positive definite tensors we are smoothing;
               of shape (batch_size, num_blocks, block_size, block_size)
        """
        eps = 1.0e-20
        if power != 1.0:
            U, S, _ = _svd(X)
            S_mean = _mean(S, exclude_dims=[0], keepdim=True)
            S = S + min_eig * S_mean + eps
            S_mean = S_mean * (1 + min_eig) + eps
            S = S / S_mean
            S = 1. / (1./S + 1./max_eig)
            S = S ** power
            S = S / _mean(S, exclude_dims=[0], keepdim=True)
            return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
        else:
            X = X.clone()
            diag = _diag(X)  # aliased with X
            mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
            diag += (mean_eig * min_eig + eps)
            cur_diag_mean = mean_eig * (1 + min_eig) + eps
            X /= cur_diag_mean.unsqueeze(-1)
            # OK, now the mean of the diagonal of X is 1 (or less than 1, in
            # which case X is extremely tiny).

            # eig_ceil is the maximum possible eigenvalue that X could possibly
            # have at this time, equal to num_blocks * block_size.
            eig_ceil = X.shape[1] * X.shape[3]

            # The next statement slightly adjusts the target to be the same as
            # what the baseline function, eig -> 1./(1./eig + 1./max_eig), would
            # give, so "max_eig" is now the target of the function for arg ==
            # eig_ceil.
            max_eig = 1. / (1. / max_eig + 1. / eig_ceil)

            while max_eig <= eig_ceil * 0.5:
                # This is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
                # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
                # increasing from 0..eig_ceil.
                # The transpose on the 2nd X is to try to stop small asymmetries from
                # propagating.
                X = X - 0.5/eig_ceil * torch.matmul(X, X.transpose(2, 3))
                eig_ceil = 0.5 * eig_ceil

            # max_eig > eig_ceil * 0.5
            if max_eig < eig_ceil:
                # Map l -> l - coeff*l*l; if l==eig_ceil, this takes us to:
                #   eig_ceil - ((eig_ceil-max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
                #   == max_eig
                # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
                # means that the function is monotonic on inputs from 0 to eig_ceil.
                coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
                X = X - coeff * torch.matmul(X, X.transpose(2, 3))

            # Normalize again.
            X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
            X = 0.5 * (X + X.transpose(2, 3))  # make sure exactly symmetric.
            return X
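
    # Illustration of the soft-max-eigenvalue logic above (hypothetical numbers):
    # with one block of size 512 (so eig_ceil=512) and max_eig=5.0, the adjusted
    # target is 1/(1/5 + 1/512) ~= 4.95.  The while-loop halves eig_ceil
    # 512 -> 256 -> ... -> 8 (stopping once 4.95 > eig_ceil*0.5), each time
    # applying l -> l - 0.5/eig_ceil*l*l; the final step then uses
    # coeff = (8 - 4.95)/64, so an eigenvalue at the ceiling 8 is mapped to ~4.95.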

    def _apply_min_max_with_metric(self,
                                   X: Tensor,
                                   M: Tensor,
                                   min_eig: float,
                                   max_eig: float) -> Tensor:
        """
        Smooths X, applying a minimum and a soft maximum to its eigenvalues (relative to
        their mean) with respect to the metric M.  Equivalent to applying:
            Y := M^{0.5} X M^{0.5}
            [apply the maximum to the eigenvalues of Y, as done in _smooth_cov()]
            X := M^{-0.5} Y M^{-0.5}

        Args:
          X: Batch of block-diagonal nonzero positive definite matrices to have the max
             eigenvalue applied
          M: Batch of positive definite block-diagonal matrices to use as metrics w.r.t. X.
        """
        X = X.clone()
        # size of the block-diagonal matrix..
        size = X.shape[1] * X.shape[3]
        # mean eig of M^{0.5} X M^{0.5} ...
        mean_eig = _sum(X*M, exclude_dims=[0], keepdim=True) / size
        # Make sure the eigs of M^{0.5} X M^{0.5} average 1; this imposes a limit on the max.
        X /= mean_eig

        if min_eig != 0.0:
            X = X * (1.0-min_eig) + min_eig * M.inverse()
            X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.

        # eig_ceil is the maximum possible eigenvalue that X could possibly
        # have at this time, equal to num_blocks * block_size.
        eig_ceil = size

        # The next statement slightly adjusts the target to be the same as
        # what the baseline function, eig -> 1./(1./eig + 1./max_eig), would
        # give, so "max_eig" is now the target of the function for arg ==
        # eig_ceil.
        max_eig = 1. / (1. / max_eig + 1. / eig_ceil)

        while max_eig <= eig_ceil * 0.5:
            # This is equivalent to the operation: l -> l - 0.5/eig_ceil*l*l
            # on eigenvalues, which maps eig_ceil to 0.5*eig_ceil and is monotonically
            # increasing from 0..eig_ceil.
            # The transpose is to try to stop small asymmetries from increasing.
            X = X - 0.5/eig_ceil * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))
            eig_ceil = 0.5 * eig_ceil

        # max_eig > eig_ceil * 0.5
        if max_eig < eig_ceil:
            # Map l -> l - coeff*l*l; if l==eig_ceil, this takes us to:
            #   eig_ceil - ((eig_ceil-max_eig)/(eig_ceil*eig_ceil))*eig_ceil*eig_ceil
            #   == max_eig
            # .. and the fact that coeff <= 0.5/eig_ceil [since max_eig > eig_ceil*0.5]
            # means that the function is monotonic on inputs from 0 to eig_ceil.
            coeff = (eig_ceil - max_eig) / (eig_ceil*eig_ceil)
            X = X - coeff * torch.matmul(X, torch.matmul(M, X.transpose(2, 3)))

        # Normalize again.
        X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
        X = 0.5 * (X + X.transpose(-2, -1))  # make sure exactly symmetric.
        return X

    def _update_grad_cov(self,
                         group: dict,
                         p: Tensor,
                         state: dict):
        """
        Updates the gradient covariance state[f"grad_cov_{dim}"] for appropriate
        dimensions of the tensor.
        """
        beta2 = group["betas"][1]
        grad = p.grad
        batch_size = p.shape[0]
        for dim in range(1, grad.ndim):  # dim 0 is batch dim.
            size = grad.shape[dim]
            name = f"grad_cov_{dim}"
            if name not in state:
                continue  # e.g. size==1, size == numel, size > max_block_size
            grad_cov = state[name]  # shaped: (batch_size, num_blocks, block_size, block_size)
            (batch_size, num_blocks, block_size, block_size) = grad_cov.shape

            g = grad.transpose(-1, dim)
            # If grad were of shape (batch_size, x, size, y, z),
            # g would be of shape (batch_size, x, z, y, size); and
            # after the next line, g will be of shape
            # (batch_size, x, z, y, num_blocks, block_size).
            g = g.reshape(*g.shape[:-1], num_blocks, block_size)
            g = _move_dim(g, -2, 1)  # (batch_size, num_blocks, x, z, y, block_size)
            g = g.reshape(batch_size, num_blocks, -1, block_size)
            # Now g is of shape (batch_size, num_blocks, x*z*y, block_size).

            # this_grad_cov: (batch_size, num_blocks, block_size, block_size)
            this_grad_cov = torch.matmul(g.transpose(-2, -1), g)

            grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))

    def _step(self,
              group: dict,
              p: Tensor,
              state: dict):
        """
        This function does the core update of self.step(), in the case where the tensor
        has more than one element.  It projects the gradient via the matrices
        state[f"Q_{dim}"] on each non-trivial axis/dimension, normalizes it by a moving
        average of its squared (projected) value, projects back, and takes a step in
        that direction.

        Args:
            group: A dict which will be used to look up configuration values
            p: The parameter to be updated (a batch of parameter tensors stacked on dim 0)
            state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        denom_rel_eps = group["denom_rel_eps"]
        step = state["step"]

        grad = self._project(grad, state, forward=True)

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                        value=(1-beta2))

        this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True)
        grad = grad / denom

        # ... and project back.
        grad = self._project(grad, state, forward=False)

        alpha = -lr * (1-beta1) * state["param_rms"]

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)
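
    # Informal summary of the update above (a sketch, not extra functionality):
    # writing Q for the per-dim projections applied by _project() and g for the
    # gradient, the parameter change accumulated into `delta` is roughly
    #
    #   delta += -lr * (1-beta1) * param_rms * Q^T ( (Q g) / sqrt(E[(Q g)^2] + eps-terms) )
    #
    # i.e. Adam-style normalization done in the projected (diagonalized) basis,
    # then mapped back to the canonical basis, scaled by the parameter's rms.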

    def _step_scalar(self,
                     group: dict,
                     p: Tensor,
                     state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.  p will not actually be a scalar, because it has a batch dim.
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                        value=1-beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # the slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr*(1-beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)

    def _project(self,
                 M: Tensor,
                 state: dict,
                 forward: bool) -> Tensor:
        """
        Multiply a tensor M by a projection Q on each of its dimensions (for which we have
        such a projection).

        Args:
           M: The tensor to project, e.g. a gradient or a parameter change.
           state: dict to look up the projections
           forward: if True, go in the forward direction (from canonical to diagonalized
               co-ordinates); if False, the reverse.  They differ by a transpose,
               not an inverse, and do not make a round trip.
        """
        numel = M.numel()

        for dim in range(1, M.ndim):  # dim 0 is batch dim (batch of parameter
                                      # tensors of the same shape)
            try:
                Q = state[f"Q_{dim}"]  # shape: (batch_size, num_blocks, block_size, block_size)
            except KeyError:
                continue  # no projection for this dim

            (batch_size, num_blocks, block_size, block_size) = Q.shape
            size = M.shape[dim]  # == num_blocks * block_size

            if forward:
                # Q is indexed
                # [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
                # the forward direction we want to change the canonical to the
                # diagonalized index, so we transpose.
                Q = Q.transpose(2, 3)

            # Assume M currently has shape (batch_size, x, size, y, z), and dim==2.
            M = M.transpose(-1, dim)  # (batch_size, x, y, z, size)
            M = M.reshape(batch_size, *M.shape[1:-1],
                          num_blocks, block_size)  # (batch_size, x, y, z, num_blocks, block_size)
            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)

            while Q.ndim < M.ndim:
                Q = Q.unsqueeze(2)
            # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
            if M.ndim < Q.ndim:  # special case where M shape is e.g. (batch_size, num_blocks, x)
                M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
            else:
                M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
            M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
        return M


def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
    """
    Moves the position of one dimension in a Tensor, while keeping
    the remaining dims in the same order.  E.g. if x has
    shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
    [10, 3, 1, 2].  Returns the reshaped tensor, which will be
    a view of the original Tensor.
    Args:
        x: Tensor to reshape
        orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
        new_dim: New position of the original dimension, with
              -x.ndim <= new_dim < x.ndim
    Returns: a permuted view of x
    """
    if orig_dim < 0:
        orig_dim += x.ndim
    if new_dim < 0:
        new_dim += x.ndim
    dims = list(range(x.ndim))
    dims[orig_dim] = -1
    dims.remove(-1)
    dims.insert(new_dim, orig_dim)
    return x.permute(dims)


def _diag(x: Tensor):
    """
    Like torch.diag(), this returns the diagonal of a matrix; but it returns an
    aliased tensor, and it supports batch dims,
    i.e. input of shape (B, M, M) returns
    output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
    (A, B, M).
    """
    stride = x.stride()
    if x.ndim == 3:
        (B, M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
    elif x.ndim == 4:
        (B, C, M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3]))
    elif x.ndim == 2:
        (M, M2) = x.shape
        assert M == M2
        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
    else:
        raise ValueError(f"_diag: unsupported ndim {x.ndim}")
    return ans
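
# Illustrative example (not used elsewhere in this file): because _diag()
# returns an aliased view, in-place writes to it modify the input, e.g.
#
#   x = torch.zeros(2, 3, 3)
#   d = _diag(x)          # shape (2, 3), shares storage with x
#   d += 1.0              # x is now a batch of 2 identity matrices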


def _sum(x: Tensor,
         exclude_dims: List[int] = [],
         keepdim: bool = False):
    """
    Version of torch.sum where you specify dims to exclude, not dims to sum.
    """
    if len(exclude_dims) == 0:
        # dim=[] works the same as tuple(range(x.ndim)), which makes
        # no sense.. supplying the full tuple for clarity and in case
        # the interface changes in future.
        return x.sum(dim=tuple(range(x.ndim)), keepdim=keepdim)
    elif x.ndim == 1:
        assert exclude_dims == [0] or exclude_dims == [-1]
        return x  # if one dim is excluded, there are no dims to sum, and
                  # x.sum(dim=[]) sums all dims so we should not call sum().
    exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
    assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
    return x.sum(dim=dims, keepdim=keepdim)


def _mean(x: Tensor,
          exclude_dims: List[int] = [],
          keepdim: bool = False):
    """
    Version of torch.mean where you specify dims to exclude, not dims to mean.
    """
    if len(exclude_dims) == 0:
        # dim=[] works the same as tuple(range(x.ndim)), which makes
        # no sense.. supplying the full tuple for clarity and in case
        # the interface changes in future.
        return x.mean(dim=tuple(range(x.ndim)), keepdim=keepdim)
    elif x.ndim == 1:
        assert exclude_dims == [0] or exclude_dims == [-1]
        return x  # if one dim is excluded, there are no dims to mean, and
                  # x.mean(dim=[]) means all dims so we should not call mean().
    exclude_dims_norm = [i if i >= 0 else i + x.ndim for i in exclude_dims]
    dims = [i for i in range(x.ndim) if i not in exclude_dims_norm]
    assert (len(dims) + len(exclude_dims) == x.ndim), exclude_dims
    return x.mean(dim=dims, keepdim=keepdim)
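
# Illustrative examples of the two helpers above (not used elsewhere):
#
#   x = torch.ones(2, 3, 4)
#   _sum(x, exclude_dims=[0]).shape                  # torch.Size([2]), values 12.0
#   _mean(x, exclude_dims=[0], keepdim=True).shape   # torch.Size([2, 1, 1])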


def _svd(x: Tensor, recursion_depth: int = 0):
    # Wrapper for torch svd that catches errors and re-tries (to address bugs in
    # some versions of torch SVD for CUDA).
    xsum = x.sum()
    if not (xsum - xsum == 0):
        raise RuntimeError("svd on inf or nan input")
    try:
        U, S, V = x.svd()
        s = U.sum()
        if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
            raise RuntimeError(f"svd failed (or random test), sum={s}")
        return U, S, V  # success
    except:
        if recursion_depth < 2:
            logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
                            f"error in SVD.  Will try again after random permutation of rows.")
            # Randomly permute the row indexes of the matrices and retry, hoping this
            # fixes the issue.
            n = x.shape[-2]
            randperm = torch.randperm(n, device=x.device)
            inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
            inv_randperm[randperm] = torch.arange(n, device=x.device)
            x = torch.index_select(x, dim=-2, index=randperm)
            U, S, V = _svd(x, recursion_depth + 1)
            return torch.index_select(U, dim=-2, index=inv_randperm), S, V
        elif recursion_depth < 4:
            logging.warning(f"svd failed after {recursion_depth} tries: "
                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}.  "
                            f"Will try an orthogonal transformation.")
            Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
                                      dtype=x.dtype).svd()
            U, S, V = _svd(torch.matmul(Urand, x),
                           recursion_depth + 1)
            return torch.matmul(Urand.t(), U), S, V
        else:
            raise RuntimeError(f"svd failed after {recursion_depth} tries")


class Cain(Optimizer):
    """Implements the Cain algorithm, which is a substantially modified form of the
    Adam optimizer.
    The modifications can be summarized as follows:

      (i) change Adam to what we called "Eve", which adds an extra configuration
          value, "target_rms" (default value: 0.1); and for non-scalar parameters,
          if their root-mean-square value exceeds "target_rms", we apply weight
          decay (aggressive enough weight decay that the rms stays at target_rms).
          This weight decay is applied in the same way as in AdamW, i.e. by
          parameter shrinkage rather than by modifying the gradients.
     (ii) By limiting the rms to 0.1, we reduce our modeling power; we restore
          it by replacing all weights and biases -- in fact, all model parameters --
          with expressions like
              weight * weight_scale.exp() or bias * bias_scale.exp().
          This has the benefit of making the learnable offsets and scales of
          BatchNorm/LayerNorm unnecessary.
    Below we describe the changes from Eve to Cain:
    (iii) We removed all the weight_scale and bias_scale parameters from
          the model and "moved them into the optimizer".  Now we allow
          parameters to have any rms value (that's <= 10), but we:
            (a) make the update speed for a parameter proportional to
                its rms value.
            (b) multiply the learning rate by 10 because we are getting rid of
                target_rms=0.1.
            (c) simulate the effect of learning the weight_scale or bias_scale
                in log space using standard Adam, with a scale to prevent it
                from learning too fast.  (see scale_exp_avg_sq in the
                code below).

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        scale_speed (float, optional): we multiply the learning rate by this value
            when learning (in log space) the scales of the parameter tensors
        rms_eps (float, optional): epsilon value such that when the parameter
            rms value goes below this, we don't slow the learning down any further for
            that parameter, i.e. we treat the rms as if it were rms_eps.  In practice,
            rms values will not go much below this.
        param_max (float, optional): maximum absolute value for any parameter;
            we will clamp parameters to satisfy -param_max <= param <= param_max


    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """

    def __init__(
        self,
        params,
        lr=1e-2,
        betas=(0.9, 0.98),
        eps=1e-8,
        scale_speed=0.05,
        rms_eps=1.0e-05,
        param_max=20.0,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 < scale_speed < 1.0:
            raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
        if not 0.1 < param_max <= 10000.0:
            raise ValueError("Invalid param_max value: {}".format(param_max))
        if not 0.0 < rms_eps < 0.1:
            raise ValueError("Invalid rms_eps value: {}".format(rms_eps))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            scale_speed=scale_speed,
            rms_eps=rms_eps,
            param_max=param_max,
        )
        super(Cain, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Cain, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            scale_speed = group["scale_speed"]
            rms_eps = group["rms_eps"]
            eps = group["eps"]
            param_max = group["param_max"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "Cain does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["delta"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if p.numel() > 1:
                        # we keep track of the RMS value of the parameters to
                        # help set the learning rate... this emulates Eve.
                        state["param_rms"] = torch.ones_like(p)
                        # we learn the scale on the parameters in a way that
                        # also emulates Eve.  scale_exp_avg_sq is the
                        # moving-average square of the gradient w.r.t. this
                        # scalar scale.
                        state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
                                                                dtype=p.dtype)

                delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
                step = state["step"]

                delta.mul_(beta1)

                if p.numel() > 1:

                    if True:
                        # This block learns the scale on the parameters.
                        # scale_deriv is the derivative w.r.t. a log-space scale
                        # on the parameters, i.e. if we imagine: p =
                        # underlying_param * scale.exp(), scale_deriv is the
                        # derivative w.r.t. scale (it does not matter
                        # what value `scale` currently has).
                        scale_deriv = (p * grad).sum()
                        scale_exp_avg_sq = state["scale_exp_avg_sq"]
                        scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                              value=1 - beta2)

                        # should actually be step + 1, so on the 1st minibatch we are not learning
                        # anything here.  May fix this at some point.
                        scale_bias_correction2 = 1 - beta2 ** (step + 1)

                        scale_denom = (scale_exp_avg_sq.sqrt()).add_(eps)

                        scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)

                        # scale_delta is going to be the change in log-scale (before momentum), i.e. if
                        # p = underlying_param * scale.exp(),
                        # delta is the change in `scale`.
                        scale_delta = scale_alpha * (scale_deriv / scale_denom)

                        delta.add_(p, alpha=scale_delta)

                    # Decay the second moment running average coefficient
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                    bias_correction2 = 1 - beta2 ** step

                    grad_rms = (exp_avg_sq.sqrt()).add_(eps)

                    this_delta = grad / grad_rms

                    param_rms = state["param_rms"]
                    self._update_param_rms(p, param_rms, step, rms_eps)

                    alpha = (-(1-beta1) * lr * bias_correction2)

                    # geometrically interpolate the per-element param_rms with its mean
                    # param_rms_half = (param_rms * param_rms.mean()) ** 0.5
                    delta.add_(this_delta * param_rms, alpha=alpha)

                    # below, will do: delta.add_(this_delta, alpha=alpha)
                else:
                    # p.numel() == 1

                    # Decay the second moment running average coefficient
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    bias_correction2 = 1 - beta2 ** step
                    denom = (exp_avg_sq.sqrt()).add_(eps)
                    this_delta = grad / denom
                    alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
                    delta.add_(this_delta, alpha=alpha)

                p.add_(delta)
                if step % 10 == 0:
                    p.clamp_(min=-param_max, max=param_max)
                state["step"] += 1

        return loss
|
|
|
|
    def _update_param_rms(self, p: Tensor, param_rms: Tensor,
                          step: int, rms_eps: float):
        """
        Updates `param_rms`.  Requires p.numel() > 1.
        Args:
            p: parameter whose magnitude/rms we are measuring
            param_rms: a Tensor containing non-negative values, with the same
                 shape as p.  Contains an estimate of the root-mean-square
                 value of the parameter.  This has a special structure where
                 it must be a product of one-dimensional quantities,
                 one for each axis of p.
                 We exclude certain axes to avoid estimating the rms value
                 from a too-small number of parameters.
            step: the step of the algorithm
            rms_eps: epsilon such that we'll aim to prevent param_rms from
                 going much lower than this.
        """
        if step % 1000 == 0:
            # Periodically refresh param_rms to keep the invariance that it is a
            # product over its axes (numerical roundoff might destroy this)
            param_rms.fill_(1.0)
            for dim in range(p.ndim):
                self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)
        else:
            dim = step % p.ndim
            self._update_param_rms_for_dim(p, param_rms, dim, rms_eps)

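    # Note on the scheduling above: apart from the full refresh every 1000
    # steps, each call updates only one axis of `param_rms`, chosen round-robin
    # as `step % p.ndim`.  E.g. for a weight matrix (p.ndim == 2), even steps
    # update the per-row factors (dim 0) and odd steps the per-column factors
    # (dim 1), so the factored estimate stays cheap to maintain.
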
    def _update_param_rms_for_dim(self, p: Tensor,
                                  param_rms: Tensor,
                                  dim: int,
                                  rms_eps: float):
        """
        Updates `param_rms` on one of its axes/dimensions.
        Args:
            p: parameter whose magnitude/rms we are measuring
            param_rms: a Tensor containing non-negative values, with the same
                 shape as p.  Contains an estimate of the root-mean-square
                 value of the parameter.  This has a special structure where
                 it must be a product of one-dimensional quantities,
                 one for each axis of p.
                 We exclude certain axes to avoid estimating the rms value
                 from a too-small number of parameters.
            dim: with 0 <= dim < p.ndim, the
                 dimension/axis on which to update the scale factor
            rms_eps: epsilon such that we'll aim to prevent param_rms from
                 going much lower than this.
        """
        # var_factor is the square of the factor to change param_rms by.  p will
        # not be super large; we limit it to -param_max <= p <= param_max.
        # Clamping p**2 to min=rms_eps*rms_eps (e.g. 1.0e-10 if rms_eps is
        # 1.0e-05) should prevent param_rms from getting much smaller than
        # rms_eps, which should prevent division by zero here.
        var_factor = (p ** 2).clamp(min=rms_eps * rms_eps) / (param_rms ** 2)

        if p.numel() <= 2 * p.shape[dim]:
            var_factor = var_factor.mean()
        else:
            dims = tuple(d for d in range(p.ndim) if d != dim)
            var_factor = var_factor.mean(dim=dims, keepdim=True)

        # if random.random() < 0.01:
        #     logging.info(f"shape={p.shape}, var_factor={var_factor}")
        param_rms.mul_(var_factor.sqrt())

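
# Illustrative sketch of the factored param_rms idea used above.  This helper
# is hypothetical (it is not referenced anywhere else in this file) and is only
# a minimal re-derivation of the idea for a 2-D tensor: the RMS estimate is
# kept as an (outer) product of one factor per row and one factor per column.
def _demo_factored_rms(p: Tensor, rms_eps: float = 1.0e-05,
                       num_iters: int = 4) -> Tensor:
    """Return a rank-1 factored estimate of the elementwise RMS of a 2-D tensor."""
    assert p.ndim == 2
    param_rms = torch.ones_like(p)
    for i in range(num_iters):
        dim = i % p.ndim
        # squared factor by which to change the current estimate
        var_factor = (p ** 2).clamp(min=rms_eps * rms_eps) / (param_rms ** 2)
        other_dims = tuple(d for d in range(p.ndim) if d != dim)
        var_factor = var_factor.mean(dim=other_dims, keepdim=True)
        param_rms.mul_(var_factor.sqrt())
    return param_rms
# For example, _demo_factored_rms(torch.randn(20, 30) * 5.0) is roughly 5.0
# everywhere, while a tensor whose rows have systematically different
# magnitudes yields row factors that track the per-row RMS.
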
class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])

        self.base_lrs = [
            group["initial_lr"] for group in optimizer.param_groups
        ]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
          state_dict (dict): scheduler state.  Should be an object returned
            from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
        # you should call this at the start of the epoch; if you don't provide the
        # 'epoch' arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.info(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )

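
# Minimal usage sketch of the LRScheduler API above.  The class name
# _DemoConstantLR is hypothetical (nothing else in this file uses it): a
# concrete scheduler only needs to override get_lr(), returning one learning
# rate per parameter group computed from self.batch, self.epoch and
# self.base_lrs; step_batch() / step_epoch() then write the values into the
# optimizer's param_groups.
class _DemoConstantLR(LRScheduler):
    """Illustration only: keeps every group's LR at its initial value."""

    def get_lr(self):
        return [base_lr for base_lr in self.base_lrs]
# Typical use (sketch): sched = _DemoConstantLR(optim), then call
# sched.step_batch() once per batch and sched.step_epoch() once per epoch.
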
class Eve(Optimizer):
    """
    Implements Eve algorithm.  This is a modified version of AdamW with a special
    way of setting the weight-decay / shrinkage-factor, which is designed to make the
    rms of the parameters approach a particular target_rms (default: 0.1).  This is
    for use with networks with 'scaled' versions of modules (see scaling.py), which
    will be close to invariant to the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-3;
            this value means that the weight would decay significantly after
            about 1k minibatches).  It is not multiplied by the learning rate,
            and is applied only while the RMS value of the parameter is
            > target_rms.
        target_rms (float, optional): target root-mean-square value of
            parameters; if they fall below this we will stop applying weight decay.


    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Eve, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "Eve does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1
                target_rms = group["target_rms"]
                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).  See the worked example after this
                    # class.
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

                if random.random() < 0.0005:
                    # occasionally log the RMS size of the parameter change.
                    step = (exp_avg / denom) * step_size
                    logging.info(
                        f"Delta rms = {(step**2).mean().sqrt().item()}, shape = {step.shape}"
                    )

        return loss

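
# Worked example of the Eve weight-decay rule above (illustration only): for a
# parameter with numel() == 100 and target_rms == 0.1, decay is applied only
# while p.norm() > 0.1 * sqrt(100) == 1.0, i.e. while the parameter's
# root-mean-square value exceeds target_rms; scalar "scaling factor"
# parameters (numel() == 1) are never decayed.
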
class Eden(LRScheduler):
    """
    Eden scheduler.
     lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                        (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))

     E.g. suggest initial-lr = 0.003 (passed to optimizer).

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need smaller number if dataset is huge
              and you will do few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs

    def get_lr(self):
        factor = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25 * (
            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
            ** -0.25
        )
        return [x * factor for x in self.base_lrs]

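
# A quick numeric illustration of the Eden formula above (illustration only):
# at batch == 0 and epoch == 0 the factor is 1.0, so training starts at the
# initial lr; at batch == lr_batches (epoch == 0) the batch term is
# 2 ** -0.25, approximately 0.84; once batch == lr_batches and
# epoch == lr_epochs the combined factor is 2 ** -0.5, approximately 0.71 of
# the initial lr, and for large batch/epoch counts the lr decays roughly as
# (batch * epoch) ** -0.5.
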
def _test_eden():
    m = torch.nn.Linear(100, 100)
    optim = Cain(m.parameters(), lr=0.003)

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")

def _test_eve_cain():
    import timeit
    from scaling import ScaledLinear
    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device('cpu')
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    # for iter in [3, 2, 1, 0]:  # will restore 1,0 later
    for iter in [3, 2]:
        fix_random_seed(42)
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        hidden_dim = 400
        m = torch.nn.Sequential(Linear(E, hidden_dim),
                                torch.nn.PReLU(),
                                Linear(hidden_dim, hidden_dim),
                                torch.nn.PReLU(),
                                Linear(hidden_dim, E),
                                ).to(device)

        train_pairs = [
            (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
             torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes)
            for _ in range(20)
        ]

        if iter == 0: optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
        elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
        elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=256)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2, 3]:
            #     optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     )  # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}")  # , norms={norm1,norm1b,norm2,norm2b}, scales={scale1,scale1b,scale2,scale2b}
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

            # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")

def _test_svd():
    device = 'cuda'
    for i in range(1000):
        X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device)
        U, S, V = _svd(X)
        X2 = torch.matmul(U * S, V.t())
        assert torch.allclose(X2, X, atol=0.001)


def _test_smooth_cov():
    b = (torch.arange(10) * 0.5).exp().diag()
    b = b.unsqueeze(0).unsqueeze(0)
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0)
    logging.info(f"c[noinv] = {_diag(c)}")
    c = PrAdam._smooth_cov(None, b, 0.0, 10.0, 1.00001)
    logging.info(f"c[svd] = {_diag(c)}")

if __name__ == "__main__":
|
|
torch.set_num_threads(1)
|
|
torch.set_num_interop_threads(1)
|
|
logging.getLogger().setLevel(logging.INFO)
|
|
import subprocess
|
|
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
|
_test_smooth_cov()
|
|
logging.info(s)
|
|
#_test_svd()
|
|
_test_eve_cain()
|
|
#_test_eden()
|