Reintroduce batching to the optimizer
This commit is contained in:
parent
00841f0f49
commit
bd7dce460b
@ -23,11 +23,93 @@ import random
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
import logging
|
import logging
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedOptimizer(Optimizer):
|
||||||
|
"""
|
||||||
|
This class adds to class Optimizer the capability to optimize parameters in batches:
|
||||||
|
it will stack the parameters and their grads for you so the optimizer can work
|
||||||
|
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
||||||
|
as it reduces the number of kernels launched in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
"""
|
||||||
|
def __init__(self, params, defaults):
|
||||||
|
super(BatchedOptimizer, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def batched_params(self, param_group):
|
||||||
|
"""
|
||||||
|
This function returns (technically, yields) a list of
|
||||||
|
of tuples (p, state), where
|
||||||
|
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
||||||
|
that share the same shape, and its gradient is also stacked;
|
||||||
|
`state` is the state corresponding to this batch of parameters
|
||||||
|
(it will be physically located in the "state" for one of the real
|
||||||
|
parameters, the last one that has any particular shape and dtype).
|
||||||
|
|
||||||
|
This function is decorated as a context manager so that it can
|
||||||
|
write parameters back to their "real" locations.
|
||||||
|
|
||||||
|
The idea is, instead of doing:
|
||||||
|
<code>
|
||||||
|
for p in group["params"]:
|
||||||
|
state = self.state[p]
|
||||||
|
...
|
||||||
|
</code>
|
||||||
|
you can do:
|
||||||
|
<code>
|
||||||
|
with self.batched_params(group["params"]) as batches:
|
||||||
|
for p, state in batches:
|
||||||
|
...
|
||||||
|
</code>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: a parameter group, which is a list of parameters; should be
|
||||||
|
one of self.groups.
|
||||||
|
"""
|
||||||
|
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||||
|
|
||||||
|
for p in param_group:
|
||||||
|
assert not p.grad.is_sparse and "Sparse gradients not supported."
|
||||||
|
key = (str(p.dtype), *p.shape)
|
||||||
|
batches[key].append(p)
|
||||||
|
|
||||||
|
stacked_params_dict = dict()
|
||||||
|
|
||||||
|
pairs = []
|
||||||
|
for p in param_group:
|
||||||
|
key = (str(p.dtype), *p.shape)
|
||||||
|
batch = batches[key]
|
||||||
|
if p is batch[0]:
|
||||||
|
# if this is the 1st param in the batch...
|
||||||
|
# we arbitrarily store the state in the
|
||||||
|
# state corresponding to the 1st parameter in the
|
||||||
|
# group. class Optimizer will take care of saving/loading state.
|
||||||
|
state = self.state[p]
|
||||||
|
p_stacked = torch.stack(batch)
|
||||||
|
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
|
||||||
|
p_stacked.grad = grad
|
||||||
|
|
||||||
|
stacked_params_dict[key] = p_stacked
|
||||||
|
pairs.append((p_stacked, state))
|
||||||
|
|
||||||
|
yield pairs # <-- calling code will do the actual optimization here!
|
||||||
|
|
||||||
|
for key, batch in batches.items():
|
||||||
|
stacked_params = stacked_params_dict[key]
|
||||||
|
for i, p in enumerate(batch): # batch is list of Parameter
|
||||||
|
p.copy_(stacked_params[i])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledAdam(Optimizer):
|
class ScaledAdam(BatchedOptimizer):
|
||||||
"""
|
"""
|
||||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||||
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
||||||
@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
|
|||||||
):
|
):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
clipping_scale=clipping_scale,
|
clipping_scale=clipping_scale,
|
||||||
@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
|
|||||||
batch = True
|
batch = True
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
|
|
||||||
|
with self.batched_params(group["params"]) as batches:
|
||||||
|
|
||||||
for i,p in enumerate(group["params"]):
|
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||||
state = self.state[p]
|
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||||
|
# a stacking dim, it is not a real dim.
|
||||||
|
|
||||||
# Perform optimization step
|
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||||
grad = p.grad
|
clipping_scale = 1
|
||||||
if grad is not None and grad.is_sparse:
|
else:
|
||||||
raise RuntimeError(
|
clipping_scale = self._get_clipping_scale(group, batches)
|
||||||
"ScaledAdam optimizer does not support sparse gradients"
|
|
||||||
)
|
|
||||||
# State initialization
|
|
||||||
if len(state) == 0:
|
|
||||||
self._init_state(group, p, state)
|
|
||||||
|
|
||||||
if i == 0:
|
for p, state in batches:
|
||||||
clipping_scale = self._get_clipping_scale(group, p, state)
|
# Perform optimization step.
|
||||||
|
# grad is not going to be None, we handled that when creating the batches.
|
||||||
|
grad = p.grad
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError(
|
||||||
|
"ScaledAdam optimizer does not support sparse gradients"
|
||||||
|
)
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
self._init_state(group, p, state)
|
||||||
|
|
||||||
if grad is None:
|
self._step_one_batch(group, p, state, clipping_scale)
|
||||||
continue
|
|
||||||
self._step_one_batch(group, p, state, clipping_scale)
|
|
||||||
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
|
|||||||
p: Tensor,
|
p: Tensor,
|
||||||
state: dict):
|
state: dict):
|
||||||
"""
|
"""
|
||||||
Initializes state dict for parameter 'p'.
|
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
||||||
|
is actually the batch dimension, corresponding to batched-together
|
||||||
|
parameters of a given shape.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group: Dict to look up configuration values.
|
group: Dict to look up configuration values.
|
||||||
p: The parameter that we are initializing the state for
|
p: The parameter that we are initializing the state for
|
||||||
state: Dict from string to whatever state we are initializing
|
state: Dict from string to whatever state we are initializing
|
||||||
"""
|
"""
|
||||||
eps = group["eps"]
|
|
||||||
size_update_period = group["size_update_period"]
|
size_update_period = group["size_update_period"]
|
||||||
|
|
||||||
state["step"] = 0
|
state["step"] = 0
|
||||||
@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
|
|||||||
p, memory_format=torch.preserve_format
|
p, memory_format=torch.preserve_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_size = p.shape[0]
|
||||||
|
numel = p.numel() // batch_size
|
||||||
numel = p.numel()
|
numel = p.numel()
|
||||||
|
|
||||||
|
|
||||||
if numel > 1:
|
if numel > 1:
|
||||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||||
# the parameter tensor.
|
# the parameter tensor.
|
||||||
param_rms = (p**2).mean().sqrt() + eps
|
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||||
|
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
|
||||||
|
keepdim=True).sqrt()
|
||||||
state["param_rms"] = param_rms
|
state["param_rms"] = param_rms
|
||||||
|
|
||||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||||
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
|
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||||
@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
|
|||||||
|
|
||||||
def _get_clipping_scale(self,
|
def _get_clipping_scale(self,
|
||||||
group: dict,
|
group: dict,
|
||||||
p: Tensor,
|
pairs: List[Tuple[Tensor, dict]]) -> float:
|
||||||
state: dict) -> float:
|
|
||||||
"""
|
"""
|
||||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||||
by this amount before applying the rest of the update.
|
by this amount before applying the rest of the update.
|
||||||
This function is only to be called for the first parameter in the group.
|
|
||||||
|
Args:
|
||||||
|
group: the parameter group, an item in self.param_groups
|
||||||
|
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
|
||||||
|
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
|
||||||
"""
|
"""
|
||||||
|
assert len(pairs) >= 1
|
||||||
clipping_scale = group["clipping_scale"]
|
clipping_scale = group["clipping_scale"]
|
||||||
step = state["step"]
|
(first_p, first_state) = pairs[0]
|
||||||
|
step = first_state["step"]
|
||||||
if clipping_scale is None or step == 0:
|
if clipping_scale is None or step == 0:
|
||||||
# no clipping. return early on step == 0 because the other
|
# no clipping. return early on step == 0 because the other
|
||||||
# parameters' state won't have been initialize yet.
|
# parameters' state won't have been initialized yet.
|
||||||
return 1.0
|
return 1.0
|
||||||
clipping_update_period = group["clipping_update_period"]
|
clipping_update_period = group["clipping_update_period"]
|
||||||
|
|
||||||
tot_sumsq = torch.tensor(0.0, device=p.device)
|
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||||
for p in group["params"]:
|
for (p, state) in pairs:
|
||||||
state = self.state[p]
|
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
if grad is None:
|
|
||||||
continue
|
|
||||||
if grad.is_sparse:
|
if grad.is_sparse:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"ScaledAdam optimizer does not support sparse gradients"
|
"ScaledAdam optimizer does not support sparse gradients"
|
||||||
)
|
)
|
||||||
if p.numel() == 1:
|
if p.numel() == p.shape[0]: # a batch of scalars
|
||||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||||
else:
|
else:
|
||||||
tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
|
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
||||||
|
|
||||||
tot_norm = tot_sumsq.sqrt()
|
tot_norm = tot_sumsq.sqrt()
|
||||||
if not "model_norms" in state:
|
if not "model_norms" in first_state:
|
||||||
state["model_norms"] = torch.zeros(clipping_update_period,
|
first_state["model_norms"] = torch.zeros(clipping_update_period,
|
||||||
device=p.device)
|
device=p.device)
|
||||||
state["model_norms"][step % clipping_update_period] = tot_norm
|
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||||
|
|
||||||
if step % clipping_update_period == 0:
|
if step % clipping_update_period == 0:
|
||||||
# We don't reach here if step == 0 because we
|
# Print some stats.
|
||||||
# would have returned above.
|
# We don't reach here if step == 0 because we would have returned
|
||||||
sorted_norms = state["model_norms"].sort()[0].to('cpu')
|
# above.
|
||||||
|
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
|
||||||
quartiles = []
|
quartiles = []
|
||||||
for n in range(0, 5):
|
for n in range(0, 5):
|
||||||
index = min(clipping_update_period - 1,
|
index = min(clipping_update_period - 1,
|
||||||
@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
|
|||||||
|
|
||||||
median = quartiles[2]
|
median = quartiles[2]
|
||||||
threshold = clipping_scale * median
|
threshold = clipping_scale * median
|
||||||
state["model_norm_threshold"] = threshold
|
first_state["model_norm_threshold"] = threshold
|
||||||
percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
|
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||||
if "num_clipped" in state else 0.0)
|
if "num_clipped" in first_state else 0.0)
|
||||||
state["num_clipped"] = 0
|
first_state["num_clipped"] = 0
|
||||||
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
|
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
|
||||||
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
|
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
|
||||||
@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
|
|||||||
if step < clipping_update_period:
|
if step < clipping_update_period:
|
||||||
return 1.0 # We have not yet estimated a norm to clip to.
|
return 1.0 # We have not yet estimated a norm to clip to.
|
||||||
else:
|
else:
|
||||||
model_norm_threshold = state["model_norm_threshold"]
|
model_norm_threshold = first_state["model_norm_threshold"]
|
||||||
ans = min(1.0,
|
ans = min(1.0,
|
||||||
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||||
if ans < 1.0:
|
if ans < 1.0:
|
||||||
state["num_clipped"] += 1
|
first_state["num_clipped"] += 1
|
||||||
if ans < 0.1:
|
if ans < 0.1:
|
||||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||||
return ans
|
return ans
|
||||||
@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
|
|||||||
state: dict,
|
state: dict,
|
||||||
clipping_scale: float):
|
clipping_scale: float):
|
||||||
"""
|
"""
|
||||||
Do the step for parameter p.
|
Do the step for one parameter, which is actually going to be a batch of
|
||||||
|
`real` parameters, with dim 0 as the batch dim.
|
||||||
Args:
|
Args:
|
||||||
group: dict to look up configuration values
|
group: dict to look up configuration values
|
||||||
p: parameter to update
|
p: parameter to update (actually multiple parameters stacked together
|
||||||
|
as a batch)
|
||||||
state: state-dict for p, to look up the optimizer state
|
state: state-dict for p, to look up the optimizer state
|
||||||
"""
|
"""
|
||||||
lr = group["lr"]
|
lr = group["lr"]
|
||||||
@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
|
|||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
|
|
||||||
delta.mul_(beta1)
|
delta.mul_(beta1)
|
||||||
numel = p.numel()
|
batch_size = p.shape[0]
|
||||||
|
numel = p.numel() // batch_size
|
||||||
if numel > 1:
|
if numel > 1:
|
||||||
# Update the size/scale of p, and set param_rms
|
# Update the size/scale of p, and set param_rms
|
||||||
scale_grads = state["scale_grads"]
|
scale_grads = state["scale_grads"]
|
||||||
scale_grads[step % size_update_period] = (p * grad).sum()
|
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||||
|
dim=list(range(1, p.ndim)), keepdim=True)
|
||||||
if step % size_update_period == size_update_period - 1:
|
if step % size_update_period == size_update_period - 1:
|
||||||
param_rms = state["param_rms"]
|
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||||
param_rms.copy_((p**2).mean().sqrt())
|
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
|
||||||
|
keepdim=True).sqrt())
|
||||||
if step > 0:
|
if step > 0:
|
||||||
# self._size_update() learns the overall scale on the
|
# self._size_update() learns the overall scale on the
|
||||||
# parameter, by shrinking or expanding it.
|
# parameter, by shrinking or expanding it.
|
||||||
@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
group: dict to look up configuration values
|
group: dict to look up configuration values
|
||||||
scale_grads: a tensor of shape (size_update_period) containing
|
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
||||||
grads w.r.t. the scales.
|
grads w.r.t. the scales.
|
||||||
p: The parameter to update
|
p: The parameter to update
|
||||||
state: The state-dict of p
|
state: The state-dict of p
|
||||||
@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
|
|||||||
param_max_rms = group["param_max_rms"]
|
param_max_rms = group["param_max_rms"]
|
||||||
eps = group["eps"]
|
eps = group["eps"]
|
||||||
step = state["step"]
|
step = state["step"]
|
||||||
|
batch_size = p.shape[0]
|
||||||
|
|
||||||
size_update_period = scale_grads.shape[0]
|
size_update_period = scale_grads.shape[0]
|
||||||
# correct beta2 for the size update period: we will have
|
# correct beta2 for the size update period: we will have
|
||||||
# faster decay at this level.
|
# faster decay at this level.
|
||||||
beta2_corr = beta2 ** size_update_period
|
beta2_corr = beta2 ** size_update_period
|
||||||
|
|
||||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
|
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||||
(scale_grads ** 2).mean(), # mean over `size_update_period` elements
|
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
|
||||||
alpha=1-beta2_corr)
|
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
|
||||||
|
|
||||||
# The 1st time we reach here is when size_step == 1.
|
# The 1st time we reach here is when size_step == 1.
|
||||||
size_step = (step + 1) // size_update_period
|
size_step = (step + 1) // size_update_period
|
||||||
@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
|
|||||||
|
|
||||||
denom = scale_exp_avg_sq.sqrt() + eps
|
denom = scale_exp_avg_sq.sqrt() + eps
|
||||||
|
|
||||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
|
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
|
||||||
|
|
||||||
is_too_small = (param_rms < param_min_rms)
|
is_too_small = (param_rms < param_min_rms)
|
||||||
is_too_large = (param_rms > param_max_rms)
|
is_too_large = (param_rms > param_max_rms)
|
||||||
@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
|
|||||||
p: Tensor,
|
p: Tensor,
|
||||||
state: dict):
|
state: dict):
|
||||||
"""
|
"""
|
||||||
This function does the core update of self.step(), in the case where the tensor
|
This function does the core update of self.step(), in the case where the members of
|
||||||
has more than 1 elements.
|
the batch have more than 1 element.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group: A dict which will be used to look up configuration values
|
group: A dict which will be used to look up configuration values
|
||||||
@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
|
|||||||
param_min_rms = group["param_min_rms"]
|
param_min_rms = group["param_min_rms"]
|
||||||
step = state["step"]
|
step = state["step"]
|
||||||
|
|
||||||
|
|
||||||
exp_avg_sq = state["exp_avg_sq"]
|
exp_avg_sq = state["exp_avg_sq"]
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||||
value=(1-beta2))
|
value=(1-beta2))
|
||||||
@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
|
|||||||
lr = group["lr"] * group["scalar_lr_scale"]
|
lr = group["lr"] * group["scalar_lr_scale"]
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
|
|
||||||
exp_avg_sq = state["exp_avg_sq"] # scalar
|
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||||
value=1-beta2)
|
value=1-beta2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user