Reintroduce batching to the optimizer

Daniel Povey 2022-10-09 20:29:23 +08:00
parent 00841f0f49
commit bd7dce460b


@@ -23,11 +23,93 @@ import random
from torch import Tensor
from torch.optim import Optimizer
import logging
import contextlib
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension. This is intended for speed with GPUs,
as it reduces the number of kernels launched in the optimizer.
Args:
params: the parameters (or parameter groups) to optimize, as for torch.optim.Optimizer.
defaults: a dict of default hyperparameter values, passed on to torch.optim.Optimizer.
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
@contextlib.contextmanager
def batched_params(self, param_group):
"""
This function returns (technically, yields) a list
of tuples (p, state), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this batch of parameters
(it will be physically located in the "state" for one of the real
parameters, the first one that has any particular shape and dtype).
This function is decorated as a context manager so that it can
write parameters back to their "real" locations.
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
with self.batched_params(group["params"]) as batches:
for p, state in batches:
...
</code>
Args:
param_group: a list of parameters; this should be the "params" list of
one of the groups in self.param_groups.
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
for p in param_group:
assert p.grad is None or not p.grad.is_sparse, "Sparse gradients not supported."
key = (str(p.dtype), *p.shape)
batches[key].append(p)
stacked_params_dict = dict()
pairs = []
for p in param_group:
key = (str(p.dtype), *p.shape)
batch = batches[key]
if p is batch[0]:
# if this is the 1st param in the batch, we arbitrarily store the
# batch's state in the state dict of this 1st parameter.
# class Optimizer will take care of saving/loading that state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
pairs.append((p_stacked, state))
yield pairs # <-- calling code will do the actual optimization here!
for key, batch in batches.items():
stacked_params = stacked_params_dict[key]
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
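To make the intended calling pattern concrete, here is a minimal sketch of a subclass that consumes batched_params(); ToySGD and its lr value are hypothetical and are not part of this commit (the real consumer is ScaledAdam below). Parameters are grouped by their (dtype, shape) key, so e.g. two weight matrices of shape (256, 512) become one stacked (2, 256, 512) tensor, while a (512,)-shaped bias forms its own batch of size 1.
<code>
import torch

class ToySGD(BatchedOptimizer):  # hypothetical illustration, not part of this commit
    def __init__(self, params, lr=0.1):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            with self.batched_params(group["params"]) as batches:
                for p, state in batches:
                    # p is a stacked (num_params_of_this_shape, *shape) tensor whose
                    # .grad is stacked the same way; in-place updates are written
                    # back to the real parameters when the context manager exits.
                    p.add_(p.grad, alpha=-lr)
</code>
Usage would then be the same as for any torch optimizer: construct it over model.parameters(), run backward(), and call step().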
class ScaledAdam(Optimizer):
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
@@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
batch = True
for group in self.param_groups:
with self.batched_params(group["params"]) as batches:
for i,p in enumerate(group["params"]):
state = self.state[p]
# batches is a list of pairs (stacked_param, state). stacked_param is like
# a regular parameter and will have a .grad, but its 1st dim is a
# stacking dim, not a real dim of the parameter.
# Perform optimization step
grad = p.grad
if grad is not None and grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
if i == 0:
clipping_scale = self._get_clipping_scale(group, p, state)
for p, state in batches:
# Perform optimization step.
# grad is not going to be None; we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
if grad is None:
continue
self._step_one_batch(group, p, state, clipping_scale)
self._step_one_batch(group, p, state, clipping_scale)
return loss
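As a hedged usage sketch of the batched ScaledAdam (the model and the hyperparameter values below are illustrative assumptions; only lr and clipping_scale are visible as constructor options in this diff):
<code>
model = torch.nn.Sequential(torch.nn.Linear(80, 256),
                            torch.nn.ReLU(),
                            torch.nn.Linear(256, 500))
optimizer = ScaledAdam(model.parameters(), lr=0.05, clipping_scale=2.0)

loss = model(torch.randn(8, 80)).pow(2).mean()
loss.backward()
optimizer.step()       # parameters of equal shape/dtype are updated as one batch
optimizer.zero_grad()
</code>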
@@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
p: Tensor,
state: dict):
"""
Initializes state dict for parameter 'p'.
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
eps = group["eps"]
size_update_period = group["size_update_period"]
state["step"] = 0
@@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
p, memory_format=torch.preserve_format
)
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel()
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
param_rms = (p**2).mean().sqrt() + eps
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
# exp_avg_sq is the weighted sum of the squares of the scaled gradients, as in Adam.
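To make the batched state layout concrete, here is an assumed shape walk-through (the sizes are made up): a batch of three weight matrices of shape (256, 512) with size_update_period = 4.
<code>
# Assumed example: 3 parameters of shape (256, 512), size_update_period = 4.
# p                         : (3, 256, 512)   stacked parameter
# p.grad                    : (3, 256, 512)   stacked gradient
# state["param_rms"]        : (3, 1, 1)       rms over all dims except the batch dim
# state["scale_exp_avg_sq"] : (3, 1, 1)
# state["scale_grads"]      : (4, 3, 1, 1)    one slot per step of the period
# state["delta"], state["exp_avg_sq"] : (3, 256, 512)   same shape as p
#   (their initialization lies mostly outside the lines shown in this hunk)
</code>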
@@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
def _get_clipping_scale(self,
group: dict,
p: Tensor,
state: dict) -> float:
pairs: List[Tuple[Tensor, dict]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
This function is only to be called once per parameter group (it uses the state of the first batch in `pairs`).
Args:
group: the parameter group, an item in self.param_groups
pairs: a list of (param, state) pairs, where param is a batched set of parameters with a .grad
(the 1st dim is the batch dim) and state is the state-dict where its optimization state is kept.
"""
assert len(pairs) >= 1
clipping_scale = group["clipping_scale"]
step = state["step"]
(first_p, first_state) = pairs[0]
step = first_state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
# parameters' state won't have been initialize yet.
# parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=p.device)
for p in group["params"]:
state = self.state[p]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state) in pairs:
grad = p.grad
if grad is None:
continue
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == 1:
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_norm = tot_sumsq.sqrt()
if not "model_norms" in state:
state["model_norms"] = torch.zeros(clipping_update_period,
if not "model_norms" in first_state:
first_state["model_norms"] = torch.zeros(clipping_update_period,
device=p.device)
state["model_norms"][step % clipping_update_period] = tot_norm
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
# We don't reach here if step == 0 because we
# would have returned above.
sorted_norms = state["model_norms"].sort()[0].to('cpu')
# Print some stats.
# We don't reach here if step == 0 because we would have returned
# above.
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
quartiles = []
for n in range(0, 5):
index = min(clipping_update_period - 1,
@@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
median = quartiles[2]
threshold = clipping_scale * median
state["model_norm_threshold"] = threshold
percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in state else 0.0)
state["num_clipped"] = 0
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in first_state else 0.0)
first_state["num_clipped"] = 0
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
@@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
model_norm_threshold = state["model_norm_threshold"]
model_norm_threshold = first_state["model_norm_threshold"]
ans = min(1.0,
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
state["num_clipped"] += 1
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
return ans
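A condensed restatement of the clipping logic in this method, with made-up numbers for illustration:
<code>
# tot_norm = sqrt( sum over all batches of ((grad * param_rms)**2).sum() )
#            (for a batch of scalars the param_rms factor is omitted).
# model_norms records the last `clipping_update_period` values of tot_norm;
# every clipping_update_period steps the threshold is refreshed:
#     threshold = clipping_scale * median(model_norms)
# and on every step once a threshold exists:
#     ans = min(1.0, threshold / (tot_norm + 1.0e-20))
# e.g. clipping_scale = 2.0, median = 3.5  ->  threshold = 7.0;
# with tot_norm = 10.0 this returns ans = 0.7, so the gradients are scaled
# by 0.7 before the rest of the update.
</code>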
@@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
state: dict,
clipping_scale: float):
"""
Do the step for parameter p.
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
@@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
delta = state["delta"]
delta.mul_(beta1)
numel = p.numel()
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum()
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"]
param_rms.copy_((p**2).mean().sqrt())
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
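A note on the quantity accumulated into scale_grads just above: it is the derivative of the loss with respect to an overall scale on each real parameter, which is what lets _size_update() learn that scale.
<code>
# If one real parameter p_i were replaced by c * p_i, then at c = 1:
#     d(loss)/dc = (p_i * grad_i).sum()
# which is exactly what is stored, per batch element, into
# scale_grads[step % size_update_period]; every size_update_period steps
# these accumulated scale gradients are consumed by _size_update().
</code>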
@@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period) containing
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
@@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
param_max_rms = group["param_max_rms"]
eps = group["eps"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(), # mean over `size_update_period` elements
alpha=1-beta2_corr)
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
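The learned-scale update above can be summarised as follows (a sketch; bias_correction2 is assumed to be the usual Adam-style 1 - beta2_corr ** size_step, its definition is outside the lines shown here):
<code>
# With K = size_update_period and beta2_corr = beta2 ** K, elementwise over the
# batch dimension (all tensors below have shape (batch_size, 1, 1, ...)):
#   scale_exp_avg_sq <- beta2_corr * scale_exp_avg_sq
#                        + (1 - beta2_corr) * mean_k( scale_grads[k] ** 2 )
#   denom      = sqrt(scale_exp_avg_sq) + eps
#   scale_step = -size_lr * sqrt(bias_correction2) * sum_k( scale_grads[k] ) / denom
# scale_step is then modified where param_rms falls outside
# [param_min_rms, param_max_rms] (the is_too_small / is_too_large masks above);
# that handling is outside the lines shown here.
</code>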
@@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
p: Tensor,
state: dict):
"""
This function does the core update of self.step(), in the case where the tensor
has more than 1 elements.
This function does the core update of self.step(), in the case where the members of
the batch have more than 1 element.
Args:
group: A dict which will be used to look up configuration values
@@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
param_min_rms = group["param_min_rms"]
step = state["step"]
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2))
@@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # scalar
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)
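For the batch-of-scalars path in this last hunk, the second-moment update is the standard Adam accumulator applied elementwise over the (batch_size,) vector of stacked scalar parameters:
<code>
#   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad * grad
# which is what exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
# computes, now over a vector of stacked scalars rather than a single scalar.
</code>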