diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 6294592a8..c1589b907 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -23,11 +23,93 @@ import random
 from torch import Tensor
 from torch.optim import Optimizer
 import logging
+import contextlib
+from collections import defaultdict
+
+
+class BatchedOptimizer(Optimizer):
+    """
+    This class adds to class Optimizer the capability to optimize parameters in batches:
+    it will stack the parameters and their grads for you so the optimizer can work
+    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
+    as it reduces the number of kernels launched in the optimizer.
+
+    Args:
+      params: the parameters or parameter groups to optimize, as in torch.optim.Optimizer
+      defaults: a dict of default hyperparameter values, as in torch.optim.Optimizer
+    """
+    def __init__(self, params, defaults):
+        super(BatchedOptimizer, self).__init__(params, defaults)
+
+    @contextlib.contextmanager
+    def batched_params(self, param_group):
+        """
+        This function returns (technically, yields) a list of
+        tuples (p, state), where
+        p is a `fake` parameter that is stacked (over axis 0) from real parameters
+        that share the same shape, and its gradient is also stacked;
+        `state` is the state corresponding to this batch of parameters
+        (it will be physically located in the "state" for one of the real
+        parameters, the first one that has any particular shape and dtype).
+
+        This function is decorated as a context manager so that it can
+        write parameters back to their "real" locations when the `with`
+        block exits.
+
+        The idea is, instead of doing:
+
+          for p in group["params"]:
+             state = self.state[p]
+             ...
+
+        you can do:
+
+          with self.batched_params(group["params"]) as batches:
+             for p, state in batches:
+                 ...
+
+        Args:
+          param_group: a parameter group, which is a list of parameters; should be
+                the "params" list of one of self.param_groups.
+        """
+        batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+
+        for p in param_group:
+            assert not p.grad.is_sparse, "Sparse gradients not supported."
+            key = (str(p.dtype), *p.shape)
+            batches[key].append(p)
+
+        stacked_params_dict = dict()
+
+        pairs = []
+        for p in param_group:
+            key = (str(p.dtype), *p.shape)
+            batch = batches[key]
+            if p is batch[0]:
+                # if this is the 1st param in the batch...
+                # we arbitrarily store the state in the
+                # state corresponding to the 1st parameter in the
+                # group.  class Optimizer will take care of saving/loading state.
+                state = self.state[p]
+                p_stacked = torch.stack(batch)
+                grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
+                p_stacked.grad = grad
+
+                stacked_params_dict[key] = p_stacked
+                pairs.append((p_stacked, state))
+
+        yield pairs  # <-- calling code will do the actual optimization here!
+
+        for key, batch in batches.items():
+            stacked_params = stacked_params_dict[key]
+            for i, p in enumerate(batch):  # batch is list of Parameter
+                p.copy_(stacked_params[i])
 
 
-class ScaledAdam(Optimizer):
+class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
@@ -85,7 +167,6 @@ class ScaledAdam(Optimizer):
     ):
 
 
-
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -121,26 +202,30 @@ class ScaledAdam(Optimizer):
         batch = True
 
         for group in self.param_groups:
+            with self.batched_params(group["params"]) as batches:
-            for i,p in enumerate(group["params"]):
-                state = self.state[p]
+                # batches is a list of pairs (stacked_param, state).  stacked_param is like
+                # a regular parameter, and will have a .grad, but its 1st dim is a stacking
+                # dim, not a real dim of the parameter.
 
-                # Perform optimization step
-                grad = p.grad
-                if grad is not None and grad.is_sparse:
-                    raise RuntimeError(
-                        "ScaledAdam optimizer does not support sparse gradients"
-                    )
-                # State initialization
-                if len(state) == 0:
-                    self._init_state(group, p, state)
+                if len(batches[0][1]) == 0:  # if len(first state) == 0: not yet initialized
+                    clipping_scale = 1
+                else:
+                    clipping_scale = self._get_clipping_scale(group, batches)
 
-                if i == 0:
-                    clipping_scale = self._get_clipping_scale(group, p, state)
+                for p, state in batches:
+                    # Perform optimization step.
+                    # grad is not going to be None; we handled that when creating the batches.
+                    grad = p.grad
+                    if grad.is_sparse:
+                        raise RuntimeError(
+                            "ScaledAdam optimizer does not support sparse gradients"
+                        )
+                    # State initialization
+                    if len(state) == 0:
+                        self._init_state(group, p, state)
 
-                if grad is None:
-                    continue
-                self._step_one_batch(group, p, state, clipping_scale)
+                    self._step_one_batch(group, p, state, clipping_scale)
 
         return loss
 
@@ -151,14 +236,16 @@ class ScaledAdam(Optimizer):
                     p: Tensor,
                     state: dict):
         """
-        Initializes state dict for parameter 'p'.
+        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
+        is actually the batch dimension, corresponding to batched-together
+        parameters of a given shape.
+
         Args:
           group: Dict to look up configuration values.
           p: The parameter that we are initializing the state for
           state: Dict from string to whatever state we are initializing
         """
-
         eps = group["eps"]
         size_update_period = group["size_update_period"]
 
         state["step"] = 0
@@ -175,16 +262,22 @@ class ScaledAdam(Optimizer):
             p, memory_format=torch.preserve_format
         )
 
-        numel = p.numel()
+        batch_size = p.shape[0]
+        numel = p.numel() // batch_size
+
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
-            param_rms = (p**2).mean().sqrt() + eps
+            # it has a shape like (batch_size, 1, 1, 1, 1)
+            param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
+                                    keepdim=True).sqrt()
             state["param_rms"] = param_rms
 
             state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
+            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
+                                               **kwargs)
 
         # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
@@ -194,45 +287,49 @@ class ScaledAdam(Optimizer):
 
     def _get_clipping_scale(self,
                             group: dict,
-                            p: Tensor,
-                            state: dict) -> float:
+                            pairs: List[Tuple[Tensor, dict]]) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e.
        we will scale the gradients by this amount before applying the rest of the update.
-        This function is only to be called for the first parameter in the group.
+
+        Args:
+           group: the parameter group, an item in self.param_groups
+           pairs: a list of pairs (param, state), where param is a batched set of parameters
+              with a .grad (1st dim is the batch dim), and state is the state-dict where the
+              optimization state is kept.
         """
+        assert len(pairs) >= 1
         clipping_scale = group["clipping_scale"]
-        step = state["step"]
+        (first_p, first_state) = pairs[0]
+        step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
-            # parameters' state won't have been initialize yet.
+            # parameters' state won't have been initialized yet.
             return 1.0
         clipping_update_period = group["clipping_update_period"]
 
-        tot_sumsq = torch.tensor(0.0, device=p.device)
-        for p in group["params"]:
-            state = self.state[p]
+        tot_sumsq = torch.tensor(0.0, device=first_p.device)
+        for (p, state) in pairs:
             grad = p.grad
-            if grad is None:
-                continue
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients"
                )
-            if p.numel() == 1:
+            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
            else:
-                tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
+                tot_sumsq += ((grad * state["param_rms"])**2).sum()
+
        tot_norm = tot_sumsq.sqrt()
-        if not "model_norms" in state:
-            state["model_norms"] = torch.zeros(clipping_update_period,
+        if "model_norms" not in first_state:
+            first_state["model_norms"] = torch.zeros(clipping_update_period,
                                               device=p.device)
-        state["model_norms"][step % clipping_update_period] = tot_norm
+        first_state["model_norms"][step % clipping_update_period] = tot_norm
 
        if step % clipping_update_period == 0:
-            # We don't reach here if step == 0 because we
-            # would have returned above.
-            sorted_norms = state["model_norms"].sort()[0].to('cpu')
+            # Print some stats.
+            # We don't reach here if step == 0 because we would have returned
+            # above.
+            sorted_norms = first_state["model_norms"].sort()[0].to('cpu')
            quartiles = []
            for n in range(0, 5):
                index = min(clipping_update_period - 1,
@@ -241,10 +338,10 @@ class ScaledAdam(Optimizer):
            median = quartiles[2]
            threshold = clipping_scale * median
-            state["model_norm_threshold"] = threshold
-            percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
-                               if "num_clipped" in state else 0.0)
-            state["num_clipped"] = 0
+            first_state["model_norm_threshold"] = threshold
+            percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period
+                               if "num_clipped" in first_state else 0.0)
+            first_state["num_clipped"] = 0
            quartiles = ' '.join([ '%.3e' % x for x in quartiles ])
            logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                         f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
@@ -252,11 +349,11 @@ class ScaledAdam(Optimizer):
        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
-            model_norm_threshold = state["model_norm_threshold"]
+            model_norm_threshold = first_state["model_norm_threshold"]
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
-                state["num_clipped"] += 1
+                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
            return ans
@@ -268,10 +365,12 @@ class ScaledAdam(Optimizer):
                        state: dict,
                        clipping_scale: float):
        """
-        Do the step for parameter p.
+        Do the step for one parameter, which is actually going to be a batch of
+        `real` parameters, with dim 0 as the batch dim.
        Args:
          group: dict to look up configuration values
-          p: parameter to update
+          p: parameter to update (actually multiple parameters stacked together
+             as a batch)
          state: state-dict for p, to look up the optimizer state
        """
        lr = group["lr"]
@@ -285,14 +384,17 @@ class ScaledAdam(Optimizer):
        delta = state["delta"]
        delta.mul_(beta1)
-        numel = p.numel()
+        batch_size = p.shape[0]
+        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum()
+            scale_grads[step % size_update_period] = (p * grad).sum(
+                dim=list(range(1, p.ndim)), keepdim=True)
            if step % size_update_period == size_update_period - 1:
-                param_rms = state["param_rms"]
-                param_rms.copy_((p**2).mean().sqrt())
+                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
+                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
+                                            keepdim=True).sqrt())
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
@@ -322,7 +424,7 @@ class ScaledAdam(Optimizer):
        Args:
          group: dict to look up configuration values
-          scale_grads: a tensor of shape (size_update_period) containing
+          scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
              grads w.r.t. the scales.
          p: The parameter to update
          state: The state-dict of p
@@ -335,16 +437,17 @@ class ScaledAdam(Optimizer):
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
+        batch_size = p.shape[0]
        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period
 
-        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # scalar
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads ** 2).mean(),  # mean over `size_update_period` elements
-            alpha=1-beta2_corr)
+            (scale_grads ** 2).mean(dim=0),  # mean over dim `size_update_period`
+            alpha=1-beta2_corr)  # shape is (batch_size, 1, 1, ...)
 
        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
@@ -354,7 +457,7 @@ class ScaledAdam(Optimizer):
 
        denom = scale_exp_avg_sq.sqrt() + eps
 
-        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
 
        is_too_small = (param_rms < param_min_rms)
        is_too_large = (param_rms > param_max_rms)
@@ -373,8 +476,8 @@ class ScaledAdam(Optimizer):
                p: Tensor,
                state: dict):
        """
-        This function does the core update of self.step(), in the case where the tensor
-        has more than 1 elements.
+        This function does the core update of self.step(), in the case where the members of
+        the batch have more than 1 element.
        Args:
            group: A dict which will be used to look up configuration values
@@ -391,7 +494,6 @@ class ScaledAdam(Optimizer):
        param_min_rms = group["param_min_rms"]
        step = state["step"]
-
        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1-beta2))
 
@@ -427,7 +529,7 @@ class ScaledAdam(Optimizer):
 
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad
-        exp_avg_sq = state["exp_avg_sq"]  # scalar
+        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
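
For reference, the batching trick that batched_params implements can be illustrated in isolation: group tensors by (dtype, shape), stack each group along a new leading dim so a single fused kernel updates all of them, then copy the results back to the real parameters. The sketch below is not part of this change; the helper name toy_batched_sgd_step and the plain SGD update are illustrative assumptions only, standing in for the ScaledAdam update.

import torch
from collections import defaultdict


def toy_batched_sgd_step(params, lr=0.1):
    # Group parameters by (dtype, shape), as BatchedOptimizer.batched_params does.
    batches = defaultdict(list)
    for p in params:
        batches[(str(p.dtype), *p.shape)].append(p)

    with torch.no_grad():
        for batch in batches.values():
            # Stack the real parameters and their grads along a new leading dim 0.
            stacked = torch.stack(batch)
            grads = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad
                                 for p in batch])
            # One fused update for the whole batch (plain SGD here, for illustration).
            stacked -= lr * grads
            # Write the updated values back to the real parameters.
            for i, p in enumerate(batch):
                p.copy_(stacked[i])


# Example: three (4, 4) weights and two (4,) biases reduce to just two stacked updates.
ps = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)] + \
     [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
loss = sum((p ** 2).sum() for p in ps)
loss.backward()
toy_batched_sgd_step(ps, lr=0.1)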