diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index b86948067..79acd7ca3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -25,73 +25,9 @@ from torch.optim import Optimizer
 import logging


-class BatchedOptimizer(Optimizer):
-    """
-    This class adds to class Optimizer the capability to optimize parameters in batches:
-    it will stack the parameters and their grads for you so the optimizer can work
-    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
-    as it reduces the number of kernels launched in the optimizer.
-
-    Args:
-      params:
-    """
-    def __init__(self, params, defaults):
-        super(BatchedOptimizer, self).__init__(params, defaults)
-
-    def batched_params(self, param_group):
-        """
-        This function returns an iterator over tuples (p, state), where
-        p is a `fake` parameter that is stacked (over axis 0) from real parameters
-        that share the same shape, and its gradient is also stacked;
-        `state` is the state corresponding to this fake parameter
-        (it will be physically located in the "state" for one of the real
-        parameters, the last one that has any particular shape and dtype).
-
-        The idea is, instead of doing:
-
-          for p in group["params"]:
-              state = self.state[p]
-              ...
-
-        you can do:
-
-          for p, state in self.batched_params(group["params"]):
-              ...
-
-
-        Args:
-          group: a parameter group, which is a list of parameters; should be
-                 one of self.groups.
-        """
-        batches = defaultdict(list)   # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
-
-        for p in param_group:
-            assert p.grad is not None and "All parameters must have a grad when you call optimizer.step()"
-            assert not p.grad.is_sparse and "Sparse gradients not supported."
-            key = (str(p.dtype), *p.shape)
-            batches[key].append(p)
-
-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:  # if this is the 1st param in the batch...
-                # we arbitrarily store the state in the
-                # state corresponding to the 1st parameter in the
-                # group.  class Optimizer will take care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([p.grad for p in batch])
-                p_stacked.grad = grad
-                yield p_stacked, state  # <-- calling code will do the actual optimization here!
-                # Now un-stack the parameter changes
-                for i,p in enumerate(batch):
-                    p.copy_(p_stacked[i])
-
-
-
-
-class ScaledAdam(BatchedOptimizer):
+class ScaledAdam(Optimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
@@ -170,7 +106,9 @@ class ScaledAdam(BatchedOptimizer):
         batch = True

         for group in self.param_groups:
-            for p, state in self.batched_params(group["params"]):
+            for p in group["params"]:
+                state = self.state[p]
+
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -191,14 +129,11 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
-        is actually the batch dimension, corresponding to batched-together
-        parameters of a given shape.
+        Initializes state dict for parameter 'p'.

         Args:
           group:   Dict to look up configuration values.
-          p: The parameter that we are initializing the state for, actually
-             will be stacked-together "real" parameters
+          p: The parameter that we are initializing the state for
           state: Dict from string to whatever state we are initializing
         """
         eps = group["eps"]
@@ -218,20 +153,17 @@ class ScaledAdam(BatchedOptimizer):
             p, memory_format=torch.preserve_format
         )

-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
             # the parameter tensor.
-            # it has a shape like (batch_size, 1, 1, 1, 1)
-            param_rms = (p**2).mean(dim=list(range(1, p.ndim)),
-                                    keepdim=True).sqrt() + eps
+            param_rms = (p**2).mean().sqrt() + eps
             state["param_rms"] = param_rms

             state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
-                                               **kwargs)
+            state["scale_grads"] = torch.zeros(size_update_period, **kwargs)
+
         # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
         state["exp_avg_sq"] = torch.zeros_like(
@@ -244,12 +176,10 @@ class ScaledAdam(BatchedOptimizer):
                     p: Tensor,
                     state: dict):
         """
-        Do the step for one parameter, which is actually going to be a batch of
-        `real` parameters, with dim 0 as the batch dim.
+        Do the step for parameter p.

         Args:
           group: dict to look up configuration values
-          p: parameter to update  (actually multiple parameters stacked together
-             as a batch)
+          p: parameter to update
           state: state-dict for p, to look up the optimizer state
         """
         lr = group["lr"]
@@ -262,17 +192,14 @@ class ScaledAdam(BatchedOptimizer):
         delta = state["delta"]

         delta.mul_(beta1)
-        batch_size = p.shape[0]
-        numel = p.numel() // batch_size
+        numel = p.numel()
         if numel > 1:
             # Update the size/scale of p, and set param_rms
             scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum(
-                dim=list(range(1, p.ndim)), keepdim=True)
+            scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
-                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
-                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)),
-                                            keepdim=True).sqrt().clamp_(min=eps))
+                param_rms = state["param_rms"]
+                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
                 if step > 0:
                     # self._size_update() learns the overall scale on the
                     # parameter, by shrinking or expanding it.
@@ -302,7 +229,7 @@ class ScaledAdam(BatchedOptimizer):

         Args:
           group: dict to look up configuration values
-          scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+          scale_grads: a tensor of shape (size_update_period) containing
              grads w.r.t. the scales.
           p: The parameter to update
           state: The state-dict of p
@@ -315,16 +242,15 @@ class ScaledAdam(BatchedOptimizer):
         param_max_rms = group["param_max_rms"]
         step = state["step"]

-        batch_size = p.shape[0]
         size_update_period = scale_grads.shape[0]
         # correct beta2 for the size update period: we will have
         # faster decay at this level.
         beta2_corr = beta2 ** size_update_period

-        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # scalar
         scale_exp_avg_sq.mul_(beta2_corr).add_(
-            (scale_grads ** 2).mean(dim=0),  # mean over dim `size_update_period`
-            alpha=1-beta2_corr)  # shape is (batch_size, 1, 1, ...)
+            (scale_grads ** 2).mean(),  # mean over `size_update_period` elements
+            alpha=1-beta2_corr)

         # The 1st time we reach here is when size_step == 1.
         size_step = (step + 1) // size_update_period
@@ -333,9 +259,9 @@ class ScaledAdam(BatchedOptimizer):
         # at the start of training.
         eps = 1.0e-10  # this value is not critical.
-        denom = scale_exp_avg_sq.sqrt() + eps  # shape: (batch_size, 1, 1, ..)
+        denom = scale_exp_avg_sq.sqrt() + eps

-        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom
+        scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom

         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
@@ -353,17 +279,13 @@ class ScaledAdam(BatchedOptimizer):
                     state: dict):
         """
         This function does the core update of self.step(), in the case where the tensor
-        has more than 1 elements.  It multiplies the moving-average gradient tensor by a
-        state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
-        and takes a step in that direction.
+        has more than 1 elements.

         Args:
           group: A dict which will be used to look up configuration values
           p: The parameter to be updated
           grad: The grad of p
           state: The state-dict corresponding to parameter p
-          batch: if true, the first dimension of p is a batch dim, i.e. a batch of
-             parameter matrices.

         This function modifies p.
         """
@@ -401,7 +323,7 @@ class ScaledAdam(BatchedOptimizer):
                     state: dict):
         """
         A simplified form of the core update for scalar tensors, where we cannot get a good
-        estimate of the parameter rms.  p will not be a scalar, because it has a batch dim.
+        estimate of the parameter rms.
         """
         beta1, beta2 = group["betas"]
         scalar_max = group["scalar_max"]
@@ -409,7 +331,7 @@ class ScaledAdam(BatchedOptimizer):
         lr = group["lr"]
         grad = p.grad

-        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
+        exp_avg_sq = state["exp_avg_sq"]  # scalar
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
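
The two sketches below are not part of the patch; they are illustrative notes on what removing the batching changes. First, a minimal comparison of the old batched `param_rms` computation with the per-parameter form introduced above. The tensor shapes and the `stacked` / `p` names are made up for the example.

```python
import torch

# Illustrative only: the old batched RMS vs. the new per-parameter RMS.
eps = 1.0e-8
stacked = torch.randn(3, 5, 7)   # old path: 3 same-shaped params stacked on dim 0
p = stacked[0]                   # new path: a single real parameter

# Old, batched form: one RMS per stacked parameter, shape (3, 1, 1).
rms_batched = (stacked ** 2).mean(dim=list(range(1, stacked.ndim)),
                                  keepdim=True).sqrt() + eps

# New form from this diff: a single scalar per real parameter.
rms_single = (p ** 2).mean().sqrt() + eps

# Entry 0 of the batched result equals the per-parameter value.
assert torch.allclose(rms_batched[0].squeeze(), rms_single)
```

The change is bookkeeping rather than behavior: each real parameter now keeps a scalar `param_rms` in its own state dict instead of sharing a `(batch_size, 1, 1, ...)` tensor with other parameters of the same shape.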
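Second, a rough standalone sketch of the per-parameter size update that `_size_update()` performs after this change, assuming a toy parameter, a fixed gradient and made-up hyper-parameter values. The bias-correction expression follows the usual Adam convention and, like the omitted rms clamping and `delta` update, is an assumption rather than something shown in the hunks above.

```python
import torch

size_update_period = 4
size_lr = 0.01
beta2 = 0.98
eps = 1.0e-8

p = torch.randn(20, 30)        # one real parameter, no leading batch dim
grad = torch.randn_like(p)     # kept fixed here; real training recomputes it

# Per-parameter state: 1-D / scalar tensors instead of (batch_size, 1, 1, ...).
scale_grads = torch.zeros(size_update_period)
scale_exp_avg_sq = torch.zeros(())   # scalar running second moment of scale grads

for step in range(2 * size_update_period):
    scale_grads[step % size_update_period] = (p * grad).sum()
    if step % size_update_period == size_update_period - 1:
        param_rms = (p ** 2).mean().sqrt().clamp(min=eps)   # scalar RMS
        beta2_corr = beta2 ** size_update_period
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(), alpha=1 - beta2_corr)
        size_step = (step + 1) // size_update_period
        # Adam-style bias correction (assumed; not shown in the hunks above).
        bias_correction2 = 1 - beta2_corr ** size_step
        denom = scale_exp_avg_sq.sqrt() + 1.0e-10
        scale_step = (-size_lr * (bias_correction2 ** 0.5)
                      * scale_grads.sum() / denom)
        # The real _size_update() also clamps against param_min_rms /
        # param_max_rms and folds the step into `delta`; omitted here.
        print(int(size_step), float(param_rms), float(scale_step))
```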