Remove batching from ScaledAdam, in preparation for adding gradient norm clipping

Daniel Povey 2022-09-16 15:42:56 +08:00
parent 3b450c2682
commit 8f876b3f54


@@ -25,73 +25,9 @@ from torch.optim import Optimizer
 import logging
-
-class BatchedOptimizer(Optimizer):
-    """
-    This class adds to class Optimizer the capability to optimize parameters in batches:
-    it will stack the parameters and their grads for you so the optimizer can work
-    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
-    as it reduces the number of kernels launched in the optimizer.
-
-    Args:
-      params:
-    """
-    def __init__(self, params, defaults):
-        super(BatchedOptimizer, self).__init__(params, defaults)
-
-    def batched_params(self, param_group):
-        """
-        This function returns an iterator over tuples (p, state), where
-        p is a `fake` parameter that is stacked (over axis 0) from real parameters
-        that share the same shape, and its gradient is also stacked;
-        `state` is the state corresponding to this fake parameter
-        (it will be physically located in the "state" for one of the real
-        parameters, the first one that has any particular shape and dtype).
-
-        The idea is, instead of doing:
-        <code>
-          for p in group["params"]:
-              state = self.state[p]
-              ...
-        </code>
-        you can do:
-        <code>
-          for p, state in self.batched_params(group["params"]):
-              ...
-        </code>
-
-        Args:
-          group: a parameter group, which is a list of parameters; should be
-                 one of self.groups.
-        """
-        batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str, *shape) to list of nn.Parameter
-        for p in param_group:
-            assert p.grad is not None, "All parameters must have a grad when you call optimizer.step()"
-            assert not p.grad.is_sparse, "Sparse gradients not supported."
-            key = (str(p.dtype), *p.shape)
-            batches[key].append(p)
-
-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:  # if this is the 1st param in the batch...
-                # we arbitrarily store the state in the
-                # state corresponding to the 1st parameter in the
-                # group.  class Optimizer will take care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([p.grad for p in batch])
-                p_stacked.grad = grad
-                yield p_stacked, state  # <-- calling code will do the actual optimization here!
-                # Now un-stack the parameter changes
-                for i, p in enumerate(batch):
-                    p.copy_(p_stacked[i])
-
-class ScaledAdam(BatchedOptimizer):
+class ScaledAdam(Optimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
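
Note: the batching that this commit removes works by grouping parameters by (dtype, shape), stacking each group with torch.stack so the update runs on a single tensor with a leading batch dimension, and copying the results back into the real parameters afterwards. A minimal standalone sketch of that grouping idea, outside any optimizer (tensor sizes and the placeholder update are illustrative, not taken from this file):

from collections import defaultdict

import torch

# Three parameters: the first two share dtype and shape, so they can be batched.
params = [torch.randn(3, 4), torch.randn(3, 4), torch.randn(5)]
for p in params:
    p.grad = torch.randn_like(p)

batches = defaultdict(list)  # (dtype_as_str, *shape) -> list of tensors
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)

for key, batch in batches.items():
    stacked = torch.stack(batch)                     # (batch_size, *shape)
    stacked_grad = torch.stack([p.grad for p in batch])
    # A real optimizer would do its update on `stacked` here, using `stacked_grad`;
    # a plain SGD-like step stands in for it.
    stacked -= 0.1 * stacked_grad
    for i, p in enumerate(batch):
        p.copy_(stacked[i])                          # write the result back

The payoff is fewer kernel launches when a model has many same-shaped parameters, which is the speed motivation stated in the removed docstring.
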
@@ -170,7 +106,9 @@ class ScaledAdam(BatchedOptimizer):
         batch = True
         for group in self.param_groups:
-            for p, state in self.batched_params(group["params"]):
+            for p in group["params"]:
+                state = self.state[p]
+
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -191,14 +129,11 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
-        is actually the batch dimension, corresponding to batched-together
-        parameters of a given shape.
+        Initializes state dict for parameter 'p'.
         Args:
           group: Dict to look up configuration values.
-          p: The parameter that we are initializing the state for, actually
-             will be stacked-together "real" parameters
+          p: The parameter that we are initializing the state for
           state: Dict from string to whatever state we are initializing
         """
         eps = group["eps"]
@@ -218,20 +153,17 @@ class ScaledAdam(BatchedOptimizer):
             p, memory_format=torch.preserve_format
         )
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
-            # it has a shape like (batch_size, 1, 1, 1, 1)
-            param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
-                                    keepdim=True).sqrt() + eps
+            param_rms = (p**2).mean().sqrt() + eps
             state["param_rms"] = param_rms
             state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
-                                               **kwargs)
+            state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
         # exp_avg_sq is the weighted sum of squared gradients, as in Adam.
         state["exp_avg_sq"] = torch.zeros_like(
@@ -244,12 +176,10 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Do the step for one parameter, which is actually going to be a batch of
-        `real` parameters, with dim 0 as the batch dim.
+        Do the step for parameter p.
         Args:
           group: dict to look up configuration values
-          p: parameter to update (actually multiple parameters stacked together
-             as a batch)
+          p: parameter to update
           state: state-dict for p, to look up the optimizer state
         """
         lr = group["lr"]
@@ -262,17 +192,14 @@ class ScaledAdam(BatchedOptimizer):
         delta = state["delta"]
         delta.mul_(beta1)
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
         if numel > 1:
             # Update the size/scale of p, and set param_rms
             scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum(
-                dim=list(range(1, p.ndim)), keepdim=True)
+            scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
-                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
-                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
-                                            keepdim=True).sqrt().clamp_(min=eps))
+                param_rms = state["param_rms"]
+                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
                 if step > 0:
                     # self._size_update() learns the overall scale on the
                     # parameter, by shrinking or expanding it.
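
Note: the size/scale bookkeeping above is now purely scalar: every step stores the single number (p * grad).sum(), the grad w.r.t. the overall scale of p, in a rolling buffer, and once per size_update_period steps the cached parameter RMS is refreshed before the scale itself is updated. A compressed sketch of that schedule (illustrative shapes and values; the actual _size_update() call is omitted):

import torch

size_update_period = 4
eps = 1.0e-8                      # illustrative value
p = torch.randn(128, 64)
param_rms = (p ** 2).mean().sqrt()
scale_grads = torch.zeros(size_update_period)

for step in range(8):
    grad = torch.randn_like(p)    # stand-in for the real gradient
    # One scalar per step: the grad w.r.t. the overall scale of p.
    scale_grads[step % size_update_period] = (p * grad).sum()
    if step % size_update_period == size_update_period - 1:
        # Refresh the cached RMS of the parameter tensor ...
        param_rms = (p ** 2).mean().sqrt().clamp(min=eps)
        # ... and this is the point where _size_update() would consume
        # `scale_grads` to shrink or expand p (omitted here).
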
@@ -302,7 +229,7 @@ class ScaledAdam(BatchedOptimizer):
         Args:
          group: dict to look up configuration values
-         scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+         scale_grads: a tensor of shape (size_update_period,) containing
            grads w.r.t. the scales.
          p:  The parameter to update
          state: The state-dict of p
@@ -315,16 +242,15 @@ class ScaledAdam(BatchedOptimizer):
         param_max_rms = group["param_max_rms"]
         step = state["step"]
-        batch_size = p.shape[0]
         size_update_period = scale_grads.shape[0]
         # correct beta2 for the size update period: we will have
         # faster decay at this level.
         beta2_corr = beta2 ** size_update_period
-        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # scalar
         scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads ** 2).mean(dim=0),  # mean over dim `size_update_period`
-            alpha=1-beta2_corr)  # shape is (batch_size, 1, 1, ...)
+            (scale_grads ** 2).mean(),  # mean over `size_update_period` elements
+            alpha=1-beta2_corr)
         # The 1st time we reach here is when size_step == 1.
         size_step = (step + 1) // size_update_period
@@ -333,9 +259,9 @@ class ScaledAdam(BatchedOptimizer):
         # at the start of training.
         eps = 1.0e-10  # this value is not critical.
-        denom = scale_exp_avg_sq.sqrt() + eps  # shape: (batch_size, 1, 1, ..)
-        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
+        denom = scale_exp_avg_sq.sqrt() + eps
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
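
Note: with the batch dimension gone, the scale update above amounts to scalar Adam on the scale gradients: a second moment decayed with beta2 corrected for the update period, then a step proportional to the summed scale grads divided by their RMS. A small numerical sketch of those lines (the size_lr value and the bias_correction2 formula are assumptions for illustration):

import torch

beta2 = 0.98
size_update_period = 4
size_lr = 0.01                                  # illustrative value
scale_grads = torch.randn(size_update_period)   # one scalar grad per step
scale_exp_avg_sq = torch.zeros(())              # scalar running second moment

# Faster decay, since this only runs once per `size_update_period` steps.
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq.mul_(beta2_corr).add_((scale_grads ** 2).mean(),
                                       alpha=1 - beta2_corr)

size_step = 1                                   # first size update
bias_correction2 = 1 - beta2_corr ** size_step  # assumed Adam-style correction
eps = 1.0e-10                                   # not critical; avoids division by zero
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
print(float(scale_step))
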
@@ -353,17 +279,13 @@ class ScaledAdam(BatchedOptimizer):
                 state: dict):
         """
         This function does the core update of self.step(), in the case where the tensor
-        has more than 1 element.  It multiplies the moving-average gradient tensor by a
-        state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
-        and takes a step in that direction.
+        has more than 1 element.
         Args:
           group: A dict which will be used to look up configuration values
           p: The parameter to be updated
           grad: The grad of p
           state: The state-dict corresponding to parameter p
-          batch: if true, the first dimension of p is a batch dim, i.e. a batch of
-             parameter matrices.
 
         This function modifies p.
         """
@@ -401,7 +323,7 @@ class ScaledAdam(BatchedOptimizer):
                        state: dict):
         """
         A simplified form of the core update for scalar tensors, where we cannot get a good
-        estimate of the parameter rms.  p will not be a scalar, because it has a batch dim.
+        estimate of the parameter rms.
         """
         beta1, beta2 = group["betas"]
         scalar_max = group["scalar_max"]
@@ -409,7 +331,7 @@ class ScaledAdam(BatchedOptimizer):
         lr = group["lr"]
         grad = p.grad
-        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
+        exp_avg_sq = state["exp_avg_sq"]  # scalar
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=1-beta2)
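
Note: for single-element parameters the second-moment buffer is likewise a scalar, so the update reduces to plain Adam-style RMS normalization with the result kept within scalar_max. A hedged sketch of that idea (the momentum/delta handling is omitted, and the final step and clamping shown here are assumptions about code not visible in this hunk):

import torch

beta2 = 0.98
lr = 0.025                    # illustrative learning rate
scalar_max = 10.0             # illustrative value for group["scalar_max"]
eps = 1.0e-8                  # illustrative

p = torch.tensor(0.5)
grad = torch.tensor(0.3)      # stand-in gradient
exp_avg_sq = torch.zeros(())  # scalar second moment, as in the new code

# Accumulate the squared gradient exactly as the hunk above does.
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# Assumed remainder of the update (momentum omitted): normalized step, then clamp.
denom = exp_avg_sq.sqrt() + eps
p = (p - lr * grad / denom).clamp(min=-scalar_max, max=scalar_max)
print(float(p))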