Reintroduce batching to the optimizer

This commit is contained in:
Daniel Povey 2022-10-09 20:29:23 +08:00
parent 00841f0f49
commit bd7dce460b

View File

@ -23,11 +23,93 @@ import random
from torch import Tensor from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
import logging import logging
import contextlib
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: parameters or parameter groups, as for class Optimizer
      defaults: dict of default hyperparameter values, as for class Optimizer
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group):
        """
        This function returns (technically, yields) a list of tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters: the first one in the group that has that shape and dtype).

        This function is decorated as a context manager so that it can
        write the (possibly modified) stacked parameters back to their
        "real" locations when the `with` block exits.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state in batches:
                 ...
        </code>

        Args:
          param_group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
        """
        # `batches` maps from tuple (dtype_as_str, *shape) to a list of
        # nn.Parameter that all share that dtype and shape.
        batches = defaultdict(list)
        for p in param_group:
            # Fix: the original wrote `assert not p.grad.is_sparse and "msg"`,
            # which (a) raised AttributeError whenever p.grad is None -- a case
            # the stacking code below explicitly supports -- and (b) used `and`
            # instead of a comma, so the message was never attached to the
            # assertion.
            if p.grad is not None:
                assert not p.grad.is_sparse, "Sparse gradients not supported."
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)

        stacked_params_dict = dict()

        # We iterate over param_group (not over `batches`) so that `pairs`
        # follows the order of first occurrence of each (dtype, shape) key
        # within the group.
        pairs = []  # list of (stacked_param, state)
        for p in param_group:
            key = (str(p.dtype), *p.shape)
            batch = batches[key]
            if p is batch[0]:
                # This is the 1st param in the batch.  We arbitrarily store the
                # batch's state in the state corresponding to this parameter;
                # class Optimizer will take care of saving/loading state.
                state = self.state[p]
                p_stacked = torch.stack(batch)
                # Missing grads are treated as zeros so the whole batch can
                # still be stacked and updated together.
                grad = torch.stack([torch.zeros_like(q) if q.grad is None else q.grad
                                    for q in batch])
                p_stacked.grad = grad
                stacked_params_dict[key] = p_stacked
                pairs.append((p_stacked, state))

        yield pairs  # <-- calling code will do the actual optimization here!

        # Copy the updated stacked values back into the real parameters.
        # NOTE(review): this assumes the `with` block runs under torch.no_grad()
        # (as optimizer step() implementations do), since copy_ into a leaf
        # parameter that requires grad would otherwise fail -- confirm callers.
        for key, batch in batches.items():
            stacked_params = stacked_params_dict[key]
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
class ScaledAdam(Optimizer): class ScaledAdam(BatchedOptimizer):
""" """
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter, proportional to the norm of that parameter; and also learn the scale of the parameter,
@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
): ):
defaults = dict( defaults = dict(
lr=lr, lr=lr,
clipping_scale=clipping_scale, clipping_scale=clipping_scale,
@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
batch = True batch = True
for group in self.param_groups: for group in self.param_groups:
with self.batched_params(group["params"]) as batches:
for i,p in enumerate(group["params"]): # batches is list of pairs (stacked_param, state). stacked_param is like
state = self.state[p] # a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
# Perform optimization step if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
grad = p.grad clipping_scale = 1
if grad is not None and grad.is_sparse: else:
raise RuntimeError( clipping_scale = self._get_clipping_scale(group, batches)
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
if i == 0: for p, state in batches:
clipping_scale = self._get_clipping_scale(group, p, state) # Perform optimization step.
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
if grad is None: self._step_one_batch(group, p, state, clipping_scale)
continue
self._step_one_batch(group, p, state, clipping_scale)
return loss return loss
@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
p: Tensor, p: Tensor,
state: dict): state: dict):
""" """
Initializes state dict for parameter 'p'. Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args: Args:
group: Dict to look up configuration values. group: Dict to look up configuration values.
p: The parameter that we are initializing the state for p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing state: Dict from string to whatever state we are initializing
""" """
eps = group["eps"]
size_update_period = group["size_update_period"] size_update_period = group["size_update_period"]
state["step"] = 0 state["step"] = 0
@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
p, memory_format=torch.preserve_format p, memory_format=torch.preserve_format
) )
batch_size = p.shape[0]
numel = p.numel() // batch_size
numel = p.numel() numel = p.numel()
if numel > 1: if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of # "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor. # the parameter tensor.
param_rms = (p**2).mean().sqrt() + eps # it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt()
state["param_rms"] = param_rms state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, **kwargs) state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
**kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam. # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
def _get_clipping_scale(self, def _get_clipping_scale(self,
group: dict, group: dict,
p: Tensor, pairs: List[Tuple[Tensor, dict]]) -> float:
state: dict) -> float:
""" """
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update. by this amount before applying the rest of the update.
This function is only to be called for the first parameter in the group.
Args:
group: the parameter group, an item in self.param_groups
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
""" """
assert len(pairs) >= 1
clipping_scale = group["clipping_scale"] clipping_scale = group["clipping_scale"]
step = state["step"] (first_p, first_state) = pairs[0]
step = first_state["step"]
if clipping_scale is None or step == 0: if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other # no clipping. return early on step == 0 because the other
# parameters' state won't have been initialize yet. # parameters' state won't have been initialized yet.
return 1.0 return 1.0
clipping_update_period = group["clipping_update_period"] clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=p.device) tot_sumsq = torch.tensor(0.0, device=first_p.device)
for p in group["params"]: for (p, state) in pairs:
state = self.state[p]
grad = p.grad grad = p.grad
if grad is None:
continue
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients" "ScaledAdam optimizer does not support sparse gradients"
) )
if p.numel() == 1: if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else: else:
tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2) tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_norm = tot_sumsq.sqrt() tot_norm = tot_sumsq.sqrt()
if not "model_norms" in state: if not "model_norms" in first_state:
state["model_norms"] = torch.zeros(clipping_update_period, first_state["model_norms"] = torch.zeros(clipping_update_period,
device=p.device) device=p.device)
state["model_norms"][step % clipping_update_period] = tot_norm first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0: if step % clipping_update_period == 0:
# We don't reach here if step == 0 because we # Print some stats.
# would have returned above. # We don't reach here if step == 0 because we would have returned
sorted_norms = state["model_norms"].sort()[0].to('cpu') # above.
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
quartiles = [] quartiles = []
for n in range(0, 5): for n in range(0, 5):
index = min(clipping_update_period - 1, index = min(clipping_update_period - 1,
@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
median = quartiles[2] median = quartiles[2]
threshold = clipping_scale * median threshold = clipping_scale * median
state["model_norm_threshold"] = threshold first_state["model_norm_threshold"] = threshold
percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in state else 0.0) if "num_clipped" in first_state else 0.0)
state["num_clipped"] = 0 first_state["num_clipped"] = 0
quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
if step < clipping_update_period: if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to. return 1.0 # We have not yet estimated a norm to clip to.
else: else:
model_norm_threshold = state["model_norm_threshold"] model_norm_threshold = first_state["model_norm_threshold"]
ans = min(1.0, ans = min(1.0,
(model_norm_threshold / (tot_norm + 1.0e-20)).item()) (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0: if ans < 1.0:
state["num_clipped"] += 1 first_state["num_clipped"] += 1
if ans < 0.1: if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
return ans return ans
@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
state: dict, state: dict,
clipping_scale: float): clipping_scale: float):
""" """
Do the step for parameter p. Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args: Args:
group: dict to look up configuration values group: dict to look up configuration values
p: parameter to update p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state state: state-dict for p, to look up the optimizer state
""" """
lr = group["lr"] lr = group["lr"]
@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
delta = state["delta"] delta = state["delta"]
delta.mul_(beta1) delta.mul_(beta1)
numel = p.numel() batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1: if numel > 1:
# Update the size/scale of p, and set param_rms # Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"] scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum() scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1: if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2).mean().sqrt()) param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
keepdim=True).sqrt())
if step > 0: if step > 0:
# self._size_update() learns the overall scale on the # self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it. # parameter, by shrinking or expanding it.
@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
Args: Args:
group: dict to look up configuration values group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period) containing scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales. grads w.r.t. the scales.
p: The parameter to update p: The parameter to update
state: The state-dict of p state: The state-dict of p
@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
param_max_rms = group["param_max_rms"] param_max_rms = group["param_max_rms"]
eps = group["eps"] eps = group["eps"]
step = state["step"] step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0] size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have # correct beta2 for the size update period: we will have
# faster decay at this level. # faster decay at this level.
beta2_corr = beta2 ** size_update_period beta2_corr = beta2 ** size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_( scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads ** 2).mean(), # mean over `size_update_period` elements (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
alpha=1-beta2_corr) alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1. # The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period size_step = (step + 1) // size_update_period
@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
denom = scale_exp_avg_sq.sqrt() + eps denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
is_too_small = (param_rms < param_min_rms) is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms) is_too_large = (param_rms > param_max_rms)
@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
p: Tensor, p: Tensor,
state: dict): state: dict):
""" """
This function does the core update of self.step(), in the case where the tensor This function does the core update of self.step(), in the case where the members of
has more than 1 elements. the batch have more than 1 element.
Args: Args:
group: A dict which will be used to look up configuration values group: A dict which will be used to look up configuration values
@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
param_min_rms = group["param_min_rms"] param_min_rms = group["param_min_rms"]
step = state["step"] step = state["step"]
exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2)) value=(1-beta2))
@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
lr = group["lr"] * group["scalar_lr_scale"] lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # scalar exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2) value=1-beta2)