diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 6294592a8..c1589b907 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -23,11 +23,93 @@ import random
from torch import Tensor
from torch.optim import Optimizer
import logging
+import contextlib
+from collections import defaultdict
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+ Args:
+ params: iterable of parameters or dicts defining parameter groups, as for class Optimizer
+ defaults: dict containing default values of optimization options, as for class Optimizer
+ """
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group):
+ """
+ This function returns (technically, yields) a list of
+ tuples (p, state), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+ parameters, the first one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+ with self.batched_params(group["params"]) as batches:
+ for p, state in batches:
+ ...
+
+
+ Args:
+ param_group: a list of parameters; should be the "params" list of
+ one of the groups in self.param_groups.
+ """
+ batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
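+ # For example, a float32 parameter of shape (256, 256) gets the key
+ # ('torch.float32', 256, 256); all parameters that share a key will be
+ # stacked into a single batch below.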
+
+ for p in param_group:
+ assert p.grad is None or not p.grad.is_sparse, "Sparse gradients not supported."
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+
+ stacked_params_dict = dict()
+
+ pairs = []
+ for p in param_group:
+ key = (str(p.dtype), *p.shape)
+ batch = batches[key]
+ if p is batch[0]:
+ # If this is the 1st param in the batch, we arbitrarily store
+ # the batch's state in the state dict corresponding to this
+ # 1st parameter. class Optimizer will take care of
+ # saving/loading that state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack([torch.zeros_like(q) if q.grad is None else q.grad for q in batch])
+ p_stacked.grad = grad
+
+ stacked_params_dict[key] = p_stacked
+ pairs.append((p_stacked, state))
+
+ yield pairs # <-- calling code will do the actual optimization here!
+
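+ # We get back here when the caller's `with` block exits: copy the
+ # (possibly updated) stacked values back into the real parameters.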
+ for key, batch in batches.items():
+ stacked_params = stacked_params_dict[key]
+ for i, p in enumerate(batch): # batch is list of Parameter
+ p.copy_(stacked_params[i])
-class ScaledAdam(Optimizer):
+class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
@@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
):
-
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
batch = True
for group in self.param_groups:
+ with self.batched_params(group["params"]) as batches:
- for i,p in enumerate(group["params"]):
- state = self.state[p]
+ # batches is a list of pairs (stacked_param, state). stacked_param is like
+ # a regular parameter, and will have a .grad, but its 1st dim is a
+ # stacking dim, not a real dim of the underlying parameters.
- # Perform optimization step
- grad = p.grad
- if grad is not None and grad.is_sparse:
- raise RuntimeError(
- "ScaledAdam optimizer does not support sparse gradients"
- )
- # State initialization
- if len(state) == 0:
- self._init_state(group, p, state)
+ if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
- if i == 0:
- clipping_scale = self._get_clipping_scale(group, p, state)
+ for p, state in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the batches.
+ grad = p.grad
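+ # grad cannot be sparse here: batched_params() builds the stacked grad
+ # with torch.stack, which yields a dense tensor, so the check below is
+ # purely defensive.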
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ # State initialization
+ if len(state) == 0:
+ self._init_state(group, p, state)
- if grad is None:
- continue
- self._step_one_batch(group, p, state, clipping_scale)
+ self._step_one_batch(group, p, state, clipping_scale)
return loss
@@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
p: Tensor,
state: dict):
"""
- Initializes state dict for parameter 'p'.
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
+ is actually the batch dimension, corresponding to batched-together
+ parameters of a given shape.
+
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
- eps = group["eps"]
size_update_period = group["size_update_period"]
state["step"] = 0
@@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
p, memory_format=torch.preserve_format
)
- numel = p.numel()
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
- param_rms = (p**2).mean().sqrt() + eps
+ # it has a shape like (batch_size, 1, 1, 1, 1)
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
+ keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
- state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
+ state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
+ **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
@@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
def _get_clipping_scale(self,
group: dict,
- p: Tensor,
- state: dict) -> float:
+ pairs: List[Tuple[Tensor, dict]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
- This function is only to be called for the first parameter in the group.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ pairs: a list of pairs (param, state), where param is a batched set of parameters with a .grad
+ (the 1st dim is the batch dim) and state is the state-dict where the optimization state is kept.
"""
+ assert len(pairs) >= 1
clipping_scale = group["clipping_scale"]
- step = state["step"]
+ (first_p, first_state) = pairs[0]
+ step = first_state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
- # parameters' state won't have been initialize yet.
+ # parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
- tot_sumsq = torch.tensor(0.0, device=p.device)
- for p in group["params"]:
- state = self.state[p]
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for (p, state) in pairs:
grad = p.grad
- if grad is None:
- continue
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
- if p.numel() == 1:
+ if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
- tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
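+ # Weight the gradient by param_rms: ScaledAdam's update for non-scalar
+ # tensors is proportional to the parameter's rms value, so this better
+ # reflects the size of the resulting update than the raw grad norm would.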
+ tot_sumsq += ((grad * state["param_rms"])**2).sum()
+
tot_norm = tot_sumsq.sqrt()
- if not "model_norms" in state:
- state["model_norms"] = torch.zeros(clipping_update_period,
+ if not "model_norms" in first_state:
+ first_state["model_norms"] = torch.zeros(clipping_update_period,
device=p.device)
- state["model_norms"][step % clipping_update_period] = tot_norm
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
- # We don't reach here if step == 0 because we
- # would have returned above.
- sorted_norms = state["model_norms"].sort()[0].to('cpu')
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
quartiles = []
for n in range(0, 5):
index = min(clipping_update_period - 1,
@@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
median = quartiles[2]
threshold = clipping_scale * median
- state["model_norm_threshold"] = threshold
- percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
- if "num_clipped" in state else 0.0)
- state["num_clipped"] = 0
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
+ if "num_clipped" in first_state else 0.0)
+ first_state["num_clipped"] = 0
quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
@@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
- model_norm_threshold = state["model_norm_threshold"]
+ model_norm_threshold = first_state["model_norm_threshold"]
ans = min(1.0,
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
- state["num_clipped"] += 1
+ first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
return ans
@@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
state: dict,
clipping_scale: float):
"""
- Do the step for parameter p.
+ Do the step for one parameter, which is actually going to be a batch of
+ `real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
- p: parameter to update
+ p: parameter to update (actually multiple parameters stacked together
+ as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
@@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
delta = state["delta"]
delta.mul_(beta1)
- numel = p.numel()
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
- scale_grads[step % size_update_period] = (p * grad).sum()
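+ # Summing (p * grad) over all dims except the batch dim gives, for each
+ # real parameter in the batch, the derivative of the loss w.r.t. an
+ # overall scale on that parameter (evaluated at scale == 1).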
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
- param_rms = state["param_rms"]
- param_rms.copy_((p**2).mean().sqrt())
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
+ keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
@@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
Args:
group: dict to look up configuration values
- scale_grads: a tensor of shape (size_update_period) containing
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
@@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
param_max_rms = group["param_max_rms"]
eps = group["eps"]
step = state["step"]
+ batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2 ** size_update_period
- scale_exp_avg_sq = state["scale_exp_avg_sq"] # scalar
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
- (scale_grads ** 2).mean(), # mean over `size_update_period` elements
- alpha=1-beta2_corr)
+ (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
denom = scale_exp_avg_sq.sqrt() + eps
- scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
+ scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
@@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
p: Tensor,
state: dict):
"""
- This function does the core update of self.step(), in the case where the tensor
- has more than 1 elements.
+ This function does the core update of self.step(), in the case where the members of
+ the batch have more than 1 element.
Args:
group: A dict which will be used to look up configuration values
@@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
param_min_rms = group["param_min_rms"]
step = state["step"]
-
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=(1-beta2))
@@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
- exp_avg_sq = state["exp_avg_sq"] # scalar
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
value=1-beta2)