Remove batching from ScaledAdam, in preparation to add gradient norm clipping

Daniel Povey 2022-09-16 15:42:56 +08:00
parent 3b450c2682
commit 8f876b3f54


@@ -25,73 +25,9 @@ from torch.optim import Optimizer
import logging
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension. This is intended for speed with GPUs,
as it reduces the number of kernels launched in the optimizer.
Args:
  params: iterable of parameters (or dicts defining parameter groups) to optimize,
    as for torch.optim.Optimizer.
  defaults: dict of default hyperparameter values, as for torch.optim.Optimizer.
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
def batched_params(self, param_group):
"""
This function returns an iterator over tuples (p, state), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this fake parameter
(it will be physically located in the "state" for one of the real
parameters: the first one in the batch that has that particular shape and dtype).
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
for p, state in self.batched_params(group["params"]):
...
</code>
Args:
  param_group: a list of parameters to be optimized together; in practice this is
    group["params"] for one of the groups in self.param_groups.
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
for p in param_group:
assert p.grad is not None, "All parameters must have a grad when you call optimizer.step()"
assert not p.grad.is_sparse, "Sparse gradients not supported."
key = (str(p.dtype), *p.shape)
batches[key].append(p)
for p in param_group:
key = (str(p.dtype), *p.shape)
batch = batches[key]
if p is batch[0]: # if this is the 1st param in the batch...
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([p.grad for p in batch])
p_stacked.grad = grad
yield p_stacked, state # <-- calling code will do the actual optimization here!
# Now un-stack the parameter changes
for i, p in enumerate(batch):
p.copy_(p_stacked[i])
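For readers unfamiliar with the batching scheme being removed here, a minimal standalone sketch of what the (dtype, shape) keying and stacking achieve; the tensor names below are illustrative and not part of the diff:

import torch

a = torch.nn.Parameter(torch.zeros(4, 3))
b = torch.nn.Parameter(torch.ones(4, 3))

key_a = (str(a.dtype), *a.shape)  # ('torch.float32', 4, 3)
key_b = (str(b.dtype), *b.shape)  # identical key, so a and b would form one batch

with torch.no_grad():
    stacked = torch.stack([a, b])   # shape (2, 4, 3): the optimizer sees one tensor
    stacked -= 0.1 * stacked        # stand-in for an optimizer update on the whole batch
    for i, p in enumerate([a, b]):
        p.copy_(stacked[i])         # un-stack the updated values back into the real parameters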
class ScaledAdam(BatchedOptimizer):
class ScaledAdam(Optimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
in proportion to the norm of that parameter; and we also learn the scale of the parameter,
@@ -170,7 +106,9 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group in self.param_groups:
for p, state in self.batched_params(group["params"]):
for p in group["params"]:
state = self.state[p]
# Perform optimization step
grad = p.grad
if grad.is_sparse:
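Taken together with the unchanged lines around this hunk, the step loop after this commit reduces to an ordinary per-parameter loop. The following is a condensed sketch, not the verbatim implementation: the helper names _init_state and _step_for_one_parameter are assumed placeholders for the per-parameter init and update methods whose docstrings appear in the hunks below, and the closure handling is standard torch.optim boilerplate.

@torch.no_grad()
def step(self, closure=None):
    """Sketch of the non-batched step loop; helper names are assumed."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            if p.grad.is_sparse:
                raise RuntimeError("ScaledAdam does not support sparse gradients")
            state = self.state[p]
            if len(state) == 0:
                self._init_state(group, p, state)          # assumed name of the init helper
            self._step_for_one_parameter(group, p, state)  # assumed name of the update helper
            state["step"] += 1
    return loss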
@@ -191,14 +129,11 @@ class ScaledAdam(BatchedOptimizer):
p: Tensor,
state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Initializes state dict for parameter 'p'.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for, actually
will be stacked-together "real" parameters
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
eps = group["eps"]
@@ -218,20 +153,17 @@ class ScaledAdam(BatchedOptimizer):
p, memory_format=torch.preserve_format
)
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel()
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt() + eps
param_rms = (p**2).mean().sqrt() + eps
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
# exp_avg_sq is the weighted sum of squared gradients, as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
@@ -244,12 +176,10 @@ class ScaledAdam(BatchedOptimizer):
p: Tensor,
state: dict):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Do the step for parameter p.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
p: parameter to update
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
@@ -262,17 +192,14 @@ class ScaledAdam(BatchedOptimizer):
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel()
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
scale_grads[step % size_update_period] = (p * grad).sum()
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt().clamp_(min=eps))
param_rms = state["param_rms"]
param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
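For readability, here is the non-batched scale-tracking branch re-assembled from the post-change lines above into one piece. It is a sketch: the nesting is inferred from the flattened diff, and the call to self._size_update() and its argument order follow the comment above and the Args list in the next hunk.

if numel > 1:
    # Record the gradient w.r.t. an implicit overall scale of p: one slot per step.
    scale_grads = state["scale_grads"]
    scale_grads[step % size_update_period] = (p * grad).sum()
    if step % size_update_period == size_update_period - 1:
        # Periodically refresh the scalar RMS of the parameter tensor.
        param_rms = state["param_rms"]
        param_rms.copy_((p ** 2).mean().sqrt().clamp_(min=eps))
        if step > 0:
            # self._size_update() learns the overall scale on the parameter,
            # by shrinking or expanding it.
            self._size_update(group, scale_grads, p, state)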
@@ -302,7 +229,7 @@ class ScaledAdam(BatchedOptimizer):
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
scale_grads: a tensor of shape (size_update_period,) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
@@ -315,16 +242,15 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms = group["param_max_rms"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
(scale_grads ** 2).mean(), # mean over `size_update_period` elements
alpha=1-beta2_corr)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -333,9 +259,9 @@ class ScaledAdam(BatchedOptimizer):
# at the start of training.
eps = 1.0e-10 # this value is not critical.
denom = scale_exp_avg_sq.sqrt() + eps # shape: (batch_size, 1, 1, ..)
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
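Pulled together, the non-batched scale update is just an Adam-style step on a single scalar. The sketch below re-assembles the post-change lines; the bias_correction2 formula is an assumption (its definition falls outside the hunks shown), and the handling of is_too_small / is_too_large is likewise outside this hunk.

beta2_corr = beta2 ** size_update_period            # faster decay over a whole update period
scale_exp_avg_sq.mul_(beta2_corr).add_(
    (scale_grads ** 2).mean(), alpha=1 - beta2_corr)

size_step = (step + 1) // size_update_period        # first reached when size_step == 1
bias_correction2 = 1 - beta2_corr ** size_step      # assumed form, analogous to Adam
denom = scale_exp_avg_sq.sqrt() + 1.0e-10
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
# scale_step is subsequently adjusted where param_rms falls outside
# [param_min_rms, param_max_rms]; that code lies outside the hunk shown here.

For example, with beta2 = 0.98 and size_update_period = 4, beta2_corr ≈ 0.922, so one size update decays the scale statistics as much as four ordinary Adam steps would.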
@@ -353,17 +279,13 @@ class ScaledAdam(BatchedOptimizer):
state: dict):
"""
This function does the core update of self.step(), in the case where the tensor
has more than one element. It multiplies the moving-average gradient tensor by a
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
and takes a step in that direction.
has more than one element.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
batch: if true, the first dimension of p is a batch dim, i.e. a batch of
parameter matrices.
This function modifies p.
"""
@@ -401,7 +323,7 @@ class ScaledAdam(BatchedOptimizer):
state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms. p will not be a scalar, because it has a batch dim.
estimate of the parameter rms.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
@@ -409,7 +331,7 @@ class ScaledAdam(BatchedOptimizer):
lr = group["lr"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq = state["exp_avg_sq"] # scalar
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
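The remainder of this scalar branch lies outside the hunk shown. As a hedged sketch only (not the author's code), a scalar Adam-style update would typically conclude along these lines, using only names already read from `group` and `state` above:

bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + group["eps"]
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
# Keep scalar parameters within [-scalar_max, scalar_max] for stability,
# then apply the momentum buffer.
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)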