mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Reintroduce batching to the optimizer
This commit is contained in:
parent
00841f0f49
commit
bd7dce460b
@ -23,11 +23,93 @@ import random
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
import logging
|
||||
import contextlib
|
||||
|
||||
|
||||
|
||||
class BatchedOptimizer(Optimizer):
|
||||
"""
|
||||
This class adds to class Optimizer the capability to optimize parameters in batches:
|
||||
it will stack the parameters and their grads for you so the optimizer can work
|
||||
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
||||
as it reduces the number of kernels launched in the optimizer.
|
||||
|
||||
Args:
|
||||
params:
|
||||
"""
|
||||
def __init__(self, params, defaults):
|
||||
super(BatchedOptimizer, self).__init__(params, defaults)
|
||||
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batched_params(self, param_group):
|
||||
"""
|
||||
This function returns (technically, yields) a list of
|
||||
of tuples (p, state), where
|
||||
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
||||
that share the same shape, and its gradient is also stacked;
|
||||
`state` is the state corresponding to this batch of parameters
|
||||
(it will be physically located in the "state" for one of the real
|
||||
parameters, the last one that has any particular shape and dtype).
|
||||
|
||||
This function is decorated as a context manager so that it can
|
||||
write parameters back to their "real" locations.
|
||||
|
||||
The idea is, instead of doing:
|
||||
<code>
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
...
|
||||
</code>
|
||||
you can do:
|
||||
<code>
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
for p, state in batches:
|
||||
...
|
||||
</code>
|
||||
|
||||
Args:
|
||||
group: a parameter group, which is a list of parameters; should be
|
||||
one of self.groups.
|
||||
"""
|
||||
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
|
||||
for p in param_group:
|
||||
assert not p.grad.is_sparse and "Sparse gradients not supported."
|
||||
key = (str(p.dtype), *p.shape)
|
||||
batches[key].append(p)
|
||||
|
||||
stacked_params_dict = dict()
|
||||
|
||||
pairs = []
|
||||
for p in param_group:
|
||||
key = (str(p.dtype), *p.shape)
|
||||
batch = batches[key]
|
||||
if p is batch[0]:
|
||||
# if this is the 1st param in the batch...
|
||||
# we arbitrarily store the state in the
|
||||
# state corresponding to the 1st parameter in the
|
||||
# group. class Optimizer will take care of saving/loading state.
|
||||
state = self.state[p]
|
||||
p_stacked = torch.stack(batch)
|
||||
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
|
||||
p_stacked.grad = grad
|
||||
|
||||
stacked_params_dict[key] = p_stacked
|
||||
pairs.append((p_stacked, state))
|
||||
|
||||
yield pairs # <-- calling code will do the actual optimization here!
|
||||
|
||||
for key, batch in batches.items():
|
||||
stacked_params = stacked_params_dict[key]
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
|
||||
|
||||
|
||||
class ScaledAdam(Optimizer):
|
||||
class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
||||
@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
|
||||
):
|
||||
|
||||
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
|
||||
batch = True
|
||||
for group in self.param_groups:
|
||||
|
||||
with self.batched_params(group["params"]) as batches:
|
||||
|
||||
for i,p in enumerate(group["params"]):
|
||||
state = self.state[p]
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
|
||||
# Perform optimization step
|
||||
grad = p.grad
|
||||
if grad is not None and grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||
clipping_scale = 1
|
||||
else:
|
||||
clipping_scale = self._get_clipping_scale(group, batches)
|
||||
|
||||
if i == 0:
|
||||
clipping_scale = self._get_clipping_scale(group, p, state)
|
||||
for p, state in batches:
|
||||
# Perform optimization step.
|
||||
# grad is not going to be None, we handled that when creating the batches.
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
|
||||
if grad is None:
|
||||
continue
|
||||
self._step_one_batch(group, p, state, clipping_scale)
|
||||
self._step_one_batch(group, p, state, clipping_scale)
|
||||
|
||||
|
||||
return loss
|
||||
@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
Initializes state dict for parameter 'p'.
|
||||
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
||||
is actually the batch dimension, corresponding to batched-together
|
||||
parameters of a given shape.
|
||||
|
||||
|
||||
Args:
|
||||
group: Dict to look up configuration values.
|
||||
p: The parameter that we are initializing the state for
|
||||
state: Dict from string to whatever state we are initializing
|
||||
"""
|
||||
eps = group["eps"]
|
||||
size_update_period = group["size_update_period"]
|
||||
|
||||
state["step"] = 0
|
||||
@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
numel = p.numel()
|
||||
|
||||
|
||||
if numel > 1:
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
param_rms = (p**2).mean().sqrt() + eps
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
|
||||
**kwargs)
|
||||
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
|
||||
|
||||
def _get_clipping_scale(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict) -> float:
|
||||
pairs: List[Tuple[Tensor, dict]]) -> float:
|
||||
"""
|
||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||
by this amount before applying the rest of the update.
|
||||
This function is only to be called for the first parameter in the group.
|
||||
|
||||
Args:
|
||||
group: the parameter group, an item in self.param_groups
|
||||
pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
|
||||
(1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
|
||||
"""
|
||||
assert len(pairs) >= 1
|
||||
clipping_scale = group["clipping_scale"]
|
||||
step = state["step"]
|
||||
(first_p, first_state) = pairs[0]
|
||||
step = first_state["step"]
|
||||
if clipping_scale is None or step == 0:
|
||||
# no clipping. return early on step == 0 because the other
|
||||
# parameters' state won't have been initialize yet.
|
||||
# parameters' state won't have been initialized yet.
|
||||
return 1.0
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=p.device)
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for (p, state) in pairs:
|
||||
grad = p.grad
|
||||
if grad is None:
|
||||
continue
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
if p.numel() == 1:
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||
else:
|
||||
tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
|
||||
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
||||
|
||||
tot_norm = tot_sumsq.sqrt()
|
||||
if not "model_norms" in state:
|
||||
state["model_norms"] = torch.zeros(clipping_update_period,
|
||||
if not "model_norms" in first_state:
|
||||
first_state["model_norms"] = torch.zeros(clipping_update_period,
|
||||
device=p.device)
|
||||
state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
# We don't reach here if step == 0 because we
|
||||
# would have returned above.
|
||||
sorted_norms = state["model_norms"].sort()[0].to('cpu')
|
||||
# Print some stats.
|
||||
# We don't reach here if step == 0 because we would have returned
|
||||
# above.
|
||||
sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
|
||||
quartiles = []
|
||||
for n in range(0, 5):
|
||||
index = min(clipping_update_period - 1,
|
||||
@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
|
||||
if "num_clipped" in state else 0.0)
|
||||
state["num_clipped"] = 0
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
|
||||
if "num_clipped" in first_state else 0.0)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
|
||||
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
|
||||
@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
|
||||
if step < clipping_update_period:
|
||||
return 1.0 # We have not yet estimated a norm to clip to.
|
||||
else:
|
||||
model_norm_threshold = state["model_norm_threshold"]
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
ans = min(1.0,
|
||||
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
state["num_clipped"] += 1
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||
return ans
|
||||
@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
|
||||
state: dict,
|
||||
clipping_scale: float):
|
||||
"""
|
||||
Do the step for parameter p.
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
p: parameter to update
|
||||
p: parameter to update (actually multiple parameters stacked together
|
||||
as a batch)
|
||||
state: state-dict for p, to look up the optimizer state
|
||||
"""
|
||||
lr = group["lr"]
|
||||
@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
|
||||
delta = state["delta"]
|
||||
|
||||
delta.mul_(beta1)
|
||||
numel = p.numel()
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
if numel > 1:
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum()
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"]
|
||||
param_rms.copy_((p**2).mean().sqrt())
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
|
||||
keepdim=True).sqrt())
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
|
||||
|
||||
Args:
|
||||
group: dict to look up configuration values
|
||||
scale_grads: a tensor of shape (size_update_period) containing
|
||||
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
||||
grads w.r.t. the scales.
|
||||
p: The parameter to update
|
||||
state: The state-dict of p
|
||||
@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
|
||||
param_max_rms = group["param_max_rms"]
|
||||
eps = group["eps"]
|
||||
step = state["step"]
|
||||
batch_size = p.shape[0]
|
||||
|
||||
size_update_period = scale_grads.shape[0]
|
||||
# correct beta2 for the size update period: we will have
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2 ** size_update_period
|
||||
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads ** 2).mean(), # mean over `size_update_period` elements
|
||||
alpha=1-beta2_corr)
|
||||
(scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
|
||||
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
|
||||
|
||||
is_too_small = (param_rms < param_min_rms)
|
||||
is_too_large = (param_rms > param_max_rms)
|
||||
@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
This function does the core update of self.step(), in the case where the tensor
|
||||
has more than 1 elements.
|
||||
This function does the core update of self.step(), in the case where the members of
|
||||
the batch have more than 1 element.
|
||||
|
||||
Args:
|
||||
group: A dict which will be used to look up configuration values
|
||||
@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
|
||||
param_min_rms = group["param_min_rms"]
|
||||
step = state["step"]
|
||||
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=(1-beta2))
|
||||
@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
|
||||
lr = group["lr"] * group["scalar_lr_scale"]
|
||||
grad = p.grad
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"] # scalar
|
||||
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
value=1-beta2)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user