From 02e409b6ceabb1b3b45b4031213ebe80ea2da378 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Aug 2021 00:00:53 +0800 Subject: [PATCH] Replace warmup with lr scheduler. --- .../madam_no_warmup.py | 1072 ----------------- .../conformer_ctc_madam_no_warmup/train.py | 36 +- 2 files changed, 28 insertions(+), 1080 deletions(-) delete mode 100644 egs/librispeech/ASR/conformer_ctc_madam_no_warmup/madam_no_warmup.py diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/madam_no_warmup.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/madam_no_warmup.py deleted file mode 100644 index f26f89b05..000000000 --- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/madam_no_warmup.py +++ /dev/null @@ -1,1072 +0,0 @@ -import logging -import math -import random -from typing import List, Tuple - -import torch -from torch import Tensor, nn -from torch.optim.optimizer import Optimizer - -# After this many warnings about infinite gradients we'll die. -inf_grad_count = 0 -inf_grad_max_count = 20 - - -class Madam(Optimizer): - r"""Madam is a modification of the Adam algorithm, with various changes - intended to support certain "common-sense" ideas and solve common - pathologies that can happen particularly in transformer-type models that - have multiplication of parameters (particularly, key and query matrices)-- - these can be vulnerable to "subspace loss" where, if you have any l2 - regularization, certain subspaces in the key/query space might get - regularized towards zero. We solve this with a special formula that - changes how the l2/weight-decay is done (see compute_l2_grad()). - I'll try to write the math down at some point. This formula only - applies to tensors that have at least two dimensions; for one-dimensional - tensors we simply won't do l2 regularization. - - One more thing-- there is a special pathology that can sometimes afflict - models like LSTMs, where a particular element of a minibatch experiences - gradient blowup in the backward pass. We'd like to identify such cases and - fix it somehow, e.g. by removing or scaling down the gradient for that - particular minibatch. We can identify and somewhat fix this by seeing that the - gradient norm (computed over all the parameters in a parameter group) is - much more than on previous minibatches, and limiting it to (the preceding - average step size times some constant). - - Like most optimization algorithms, for this to work well you need to - have an appropriate learning rate schedule, either decreasing with - time, or increasing (warm-up) and then decreasing. The LR schedule may - possibly need to decrease a little more aggressively than you would with - Adam, or at least have smaller values overall than Adam, because - the smaller parameters will mean the effective (relative) learning - rate is higher. - - This is modified from PyTorch's optim/adam.py - - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - grad_norm_buffer_size (int, optional): Buffer size used in detecting - minibatches with unusually large gradients and scaling them down. - limit_grad_factor (float): factor by which we don't allow the - gradient to be greater than the average of previous gradients - (we'll scale the gradient down, over the whole param-group, - to enforce this). Must be greater than 1. Set to float('inf') - to disable norm clipping. - min_target_rms: A floor on the "target rms" of each Tensor, so - that Tensors that, when initialized, have less than this - rms value will have their target rms value floored to this - l2: True to enable l2 regularization - l2_period: You may set this to a value greater than one to save - computation by only periodically doing the l2 update. - We include a scaling factor in the formula so that, as far - as possible (for small learning rates) this shouldn't affect - the results. (Note: this probably isn't necessary to set, - since it turns out the update is quite fast, at least on GPU, - and the gradient clipping is actually more of a problem) - - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - - """ - - def __init__( - self, - params, - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - grad_norm_buffer_size: int = 8, - limit_grad_factor: float = 2.0, - min_target_rms: float = 0.05, - l2: bool = True, - l2_period: int = 1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) - if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) - if not ( - isinstance(grad_norm_buffer_size, int) and grad_norm_buffer_size > 1 - ): - raise ValueError( - "Invalid grad_norm_buffer_size value: {}".format( - grad_norm_buffer_size - ) - ) - if not limit_grad_factor > 1.0: - raise ValueError( - "Invalid limit_grad_factor: {}".format(limit_grad_factor) - ) - if not isinstance(l2, bool): - raise ValueError("Invalid l2 value: {}".format(l2)) - if not l2_period >= 1: - raise ValueError("Invalid l2_period value: {}".format(l2_period)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - grad_norm_buffer_size=grad_norm_buffer_size, - limit_grad_factor=limit_grad_factor, - l2=l2, - l2_period=l2_period, - min_target_rms=min_target_rms, - ) - super(Madam, self).__init__(params, defaults) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - beta1, beta2 = group["betas"] - grad_norm_buffer_size = group["grad_norm_buffer_size"] - limit_grad_factor = group["limit_grad_factor"] - min_target_rms = group["min_target_rms"] - - # The next 5 lists are part of the original Adam optimizer - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - - # The next 3 lists are not part of the original Adam optimizer. - target_rms_values = ( - [] - ) # relates to weight decay. Target root-mean-square - # values of the elements of each parameter - # we are optimizing - prev_norm_stats = ( - [] - ) # contains Tensor with 2 elements each, the sum - # of the [sum_squared, count] of - # this parameter on previous minibatches (up to - # grad_norm_buffer_size minibatches) - cur_grad_norms = [] # and `cur_grad_norms` contains the squared l2 - # norm norm of this step's gradient for this - # parameter, as a Tensor. - - for p in group["params"]: - if p.grad is not None: - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) - grads.append(p.grad) - - state = self.state[p] - # Lazy state initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - # The things below are not part of original Adam, they are the Madam extension.. - state["target_rms"] = _get_target_rms(p, min_target_rms) - # grad_norm_buf is a rotating buffer containing (grad_norm**2, count), where - # count is 1 for grad_norms that are set and 0 for those that are not set because - # we're near step 0 or because they were infinite. - state["grad_norm_buf"] = torch.zeros( - grad_norm_buffer_size, 2, device=p.device - ) - - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) - - target_rms_values.append(state["target_rms"]) - - cur_step = state["step"] - if limit_grad_factor != float("inf"): - grad_norm_buf = state["grad_norm_buf"] - cur_grad_norm = ( - p.grad ** 2 - ).sum() # actually squared nom - prev_mean_norm = grad_norm_buf.sum( - 0 - ) # prev_mean_norm is a Tensor [ tot_norm_squared, count ] - grad_norm_buf[cur_step % grad_norm_buffer_size][ - 0 - ] = cur_grad_norm - grad_norm_buf[cur_step % grad_norm_buffer_size][ - 1 - ].fill_(1.0) - prev_norm_stats.append(prev_mean_norm) - cur_grad_norms.append(cur_grad_norm) - - # update the steps for each param group update - cur_step += 1 - state["step"] = cur_step - # record the step after step update - state_steps.append(cur_step) - - if limit_grad_factor != float("inf"): - self._apply_grad_norm_clipping( - group["params"], - prev_norm_stats, - cur_grad_norms, - grads, - limit_grad_factor, - grad_norm_buffer_size, - ) - - _madam( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - target_rms_values, - beta1=beta1, - beta2=beta2, - lr=group["lr"], - eps=group["eps"], - l2=group["l2"], - l2_period=group["l2_period"], - ) - - return loss - - def _apply_grad_norm_clipping( - self, - params_list, - prev_norm_stats: List[Tensor], - cur_grad_norms: List[Tensor], - grads: List[Tensor], - limit_grad_factor: float, - grad_norm_buffer_size: int, - ) -> None: - """ - This function applies gradient norm clipping for this parameter group if this - minibatch has substantially larger gradients in this param group than - recent minibatches. The idea is to catch cases like where an LSTM - happens to blow up in the backward pass, or some code bug causes very - large or infinite gradients on a particular minibatch; so we scale - down any very large gradients and zero infinite ones. - - Args: - params_list: some kind of iterable or list of params in this group - prev_norm_stats: a list which, for each parameter in this group - with a grad, contains a Tensor with 2 elements each, containing - # the [sum, count] of up to `grad_norm_buffer_size` - # norms of this parameter on previous minibatches; - cur_grad_norms: a list of Tensor containing, for each parameter in this group, - the norm of this step's gradient for this parameter. - grads: List of gradients with the same order as prev_norm_stats and - cur_grad_norms - limit_grad_factor: a float >1.0 (e.g. 4.0) that dictates - how-much-larger-than-average gradients we allow before clipping. - grad_norm_buffer_size: an int that determines the rolling buffer size over which - we store gradient norms - """ - num_params = len(prev_norm_stats) - assert len(grads) == num_params - - all_prev_norm_stats, all_cur_grad_norms = _to_device( - "cpu", torch.stack(prev_norm_stats), torch.stack(cur_grad_norms) - ) - assert all_prev_norm_stats.shape == (num_params, 2) - assert all_cur_grad_norms.shape == (num_params,) - - # divide totals by counts (i.e. counts of iterations were we stored - # a finite grad) - all_prev_grad_norms = ( - all_prev_norm_stats[:, 0] / all_prev_norm_stats[:, 1] - ) - # prev_norm and cur_norm are floats, they are actually squared norms. - prev_norm = all_prev_grad_norms.sum().item() - cur_norm = all_cur_grad_norms.sum().item() - - if prev_norm - prev_norm != 0.0: - # There were zero counts; fix this by using the current grad norm - # for affected parameters, and recompute all_prev_grad_norms and - # prev_norm. - for i in range(num_params): - if all_prev_norm_stats[i][1] == 0.0: - # if count is 0 and cur norm is finite, use cur norm as our estimate - # of previous norms. This would only be useful if some but not - # all params were in this situation of having no previous estimates. - cur = all_cur_grad_norms[i] - if cur - cur == 0.0: # finite.. - all_prev_norm_stats[i][0] = cur - all_prev_norm_stats[i][1] = 1.0 - else: - # 0.0 is a default; likely won't matter, as if we - # get infinite `cur`, we'll abandon this minibatch. - all_prev_norm_stats[i][0] = 0.0 - all_prev_grad_norms = ( - all_prev_norm_stats[:, 0] / all_prev_norm_stats[:, 1] - ) - prev_norm = all_prev_grad_norms.sum().item() - - # Deal with infinite gradients. - if cur_norm - cur_norm != 0: # cur_norm is infinite or NaN - global inf_grad_count - logging.warning( - f"Infinite gradient-norm detected (cur/prev: {cur_norm}/{prev_norm}): will " - f"zero grad ({inf_grad_count}/{inf_grad_max_count} times until dying)" - ) - inf_grad_count += 1 - if inf_grad_count >= inf_grad_max_count: - assert 0, "Reached max count of infinite gradient-norm stats" - # Zero all gradients in this group - for g in grads: - g[:] = 0.0 - # .. and zero the stored gradient norms in grad_norm_buf (so - # that infinities don't ruin our stats of previous batches) - for p in params_list: - if p.grad is not None: - state = self.state[p] - grad_norm_buf = state["grad_norm_buf"] - # cur_step is the location where we would have written the grad_norm. - # We didn't check if it was infinity before, because we didn't want to - # incur lots of GPU->CPU transfers. - cur_step = state["step"] - 1 - # Remove this 'bad' step from the buffer. - grad_norm_buf[cur_step % grad_norm_buffer_size][:] = 0.0 - else: - # cur_norm is finite. Check whether we have to clip this iteration's grad. - # we always remove infinities/NaNs from the buffer, so prev_norm should not - # be infinite or NaN. - assert prev_norm - prev_norm == 0.0 - # cur_norm and prev_norm are actually squared norms, so we need to - # square limit_grad_factor.. - limit_grad_factor2 = limit_grad_factor ** 2 - if cur_norm > prev_norm * limit_grad_factor2: - grad_factor2 = (prev_norm * limit_grad_factor2) / cur_norm - grad_factor = grad_factor2 ** 0.5 - cur_norm_f, prev_norm_f, grad_factor_f = ( - "%.2g" % cur_norm, - "%.2g" % prev_norm, - "%.2g" % grad_factor, - ) - logging.warning( - f"Gradient norm exceeds average of last {grad_norm_buffer_size} " - f"gradients times {limit_grad_factor}: cur/prev {cur_norm_f}/{prev_norm_f}: " - f"scaling it by {grad_factor_f}." - ) - for g in grads: - g[:] *= grad_factor - # .. and scale down the stored gradient norms in grad_norm_buf, to - # avoid the bound getting too loose too quickly. - for p in params_list: - if p.grad is not None: - state = self.state[p] - grad_norm_buf = state["grad_norm_buf"] - cur_step = state["step"] - 1 - # the buffer contains squared norms, so multiply by grad_factor2 - grad_norm_buf[cur_step % grad_norm_buffer_size][ - 0 - ] *= grad_factor2 - - -def _to_device(device, *args): - """ - Transfers a tuple of Tensors from one device to another, using a single transfer. Must have - same dtype but may have different shapes. - E.g. - (cpu_tensor_a, cpu_tensor_b) = _to_device('cpu', gpu_tensor_a, gpu_tensor_b) - """ - if device == args[0].device: - return args - else: - arg0 = args[0] - combined_src = torch.cat([x.reshape(-1) for x in args]) - combined_dest = combined_src.to(device) - dests = [] - offset = 0 - for src in args: - numels = src.numel() - dests.append( - combined_dest[offset : offset + numels].reshape(src.shape) - ) - offset += numels - return tuple(dests) - - -def _get_target_rms(x: Tensor, min_target_rms: float) -> Tensor: - """ - Returns Tensor with one element, representing a target root-mean-square - value of elements of x, that we consider "reasonable", and will use a - as a "target rms" in our modified weight-decay formula. It returns - the maximum of the current RMS of the values of x, and `min_target_rms`, - as a Tensor on the same device as x. - """ - with torch.no_grad(): - # `norm` is the 2-norm of x currently (and this function should be - # called right after parameter initialization) - rms = ((x ** 2).sum() / x.numel()).sqrt() - largest_dim = max(list(x.shape)) - numel = x.numel() - if min_target_rms > 0.0: - rms = rms.clamp(min=min_target_rms) - if ( - x.ndim > 1 and __name__ == "__main__" - ): # will only be used for x.ndim > 1. - print("Target rms = ", rms) # Print this in testing only. - return rms - - -def _madam( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[int], - target_rms_values: List[Tensor], - *, - beta1: float, - beta2: float, - lr: float, - eps: float, - l2: bool, - l2_period: int, -): - r"""This is a modification of adam() from torch's optim/_functional.py. - - It has been modified to: - (i) remove the amsgrad option; this shouldn't be as necessary due to - the adaptive gradient norm clipping we have added - (ii) add our special formula for l2 regularization. This doesn't have - any tunable parameters, other than the target standard deviation - of the elements of the tensor (which is passed in as target_rms). - Args: - params: list of Tensor, containing the parameters to be optimized - grads: list of Tensor, containing the gradients corresponding to - each of the params (grads[i] should correspond to params[i].grad, - although it may have undergone gradient clipping). - exp_avgs: list of Tensor, containing tensors with the same dimensions - as params and grads, that contain the moving-averages of - `grads`. - exp_avg_sqs: list of Tensor, containing tensors with the same dimensions - as params and grads, that contain the moving-averages of - `grads ** 2`. - state_steps: list of int, containing the step for each parameter (step >= 1) - target_rms_values: list of Tensor with one element each, containing the - target root-mean-square values of each parameter tensor in `params` - l2: a bool, where if true we will activate the l2 regularization - formula. - l2_period: an integer that determines how often (i.e. every how many - minibatches) we apply the l2 update. We include a scaling factor - so that as far as possible the result will not be too sensitive - to the value of this. - - beta1: decay factor for gradients, e.g. 0.9 - beta2: decay factor for gradients squared, e.g. 0.999 - lr: learning rate, e.g. 0.0001 - eps: a small constant used to prevent division by zero, e.g. 1.0e-8 - - See :class:`~torch.optim.Adam` for details. - """ - assert ( - len(params) - == len(grads) - == len(state_steps) - == len(exp_avgs) - == len(exp_avg_sqs) - == len(target_rms_values) - ) - - for i, param in enumerate(params): - - grad = grads[i] - - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = state_steps[i] - target_rms = target_rms_values[i] - - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - do_l2 = param.ndim > 1 and l2 and step % l2_period == 0 - - if do_l2: - # This represents just the "noise term" of the gradient, i.e. the grad minus the - # running mean. We'll later divide by denom. - cur_grad_noise = grad - exp_avg - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - if not do_l2: - param.addcdiv_(exp_avg, denom, value=-step_size) - else: - # We can treat "pseudo_grad" as if it were a gradient (even though it's - # actually a gradient times a per-element learning rate). The analysis - # that we used to figure out what the l2 should be did not use the fact - # that the gradients were actually gradients, it simply analyzed it as a - # quantity that can be treated as close to zero-mean and with a certain - # structure of variance, and added to the param with the formula: - # - # param -= step_size * grad - # - # The original analysis assumed the gradients were independent from frame - # to frame; in fact these are not, but the important difference can be captured - # in a scalar `grad_scale` that expresses the scale of pseudo_grad relative - # to the independent gradients that we are effectively adding on each frame - # (but with a delay). - - pseudo_grad = exp_avg / denom - cur_pseudo_grad = cur_grad_noise / denom - - # grad_scale expresses the expected size of cur_pseudo_grad relative to the - # original grads if we had not done the moving-average; it is the sqrt of - # the sum of the squares of coefficients of previous gradients: - # c_n = (1-beta1) beta1^n, for - # n = 0, 1, .. - # .. plus one which is the sumsq of the coefficient of 'grad' itself in - # (grad - exp_avg). - # It is relevant that the sum of the coefficients (i.e. not squared) is 1; - # if this were not so we'd have to incorporate that into the formula for l2. - grad_scale = (((1 - beta1) ** 2) / (1 - beta1 ** 2) + 1) ** 0.5 - - with torch.no_grad(): - l2_grad = _compute_l2_grad( - param, - cur_pseudo_grad, - target_rms, - rho=step_size, - grad_scale=grad_scale, - period_scale=l2_period, - eps=eps, - safe=True, - ) - - # TODO: could alternate computing l2 on only, say, odd frames, and scale it - # up by 2, to save time. - param.add_(pseudo_grad + l2_grad, alpha=-step_size) - - -def _view_as_matrix(x: Tensor, dim: int) -> Tensor: - """ - Returns a Tensor of shape (n, x.shape[dim]), where n is the product - of the sizes of the other dimensions of x. This may involve a copy, - if x cannot be reshaped in this way. - """ - ndim = x.ndim - assert ndim > 1 and dim >= 0 and dim < ndim - # Move the dim to the last position in x.. - if dim != ndim - 1: - x = x.transpose(dim, ndim - 1) - return x.reshape(-1, x.shape[-1]) - - -def _outer_product(x: Tensor, dim: int) -> Tensor: - """ - Returns a Tensor of shape (x.shape[dim], x.shape[dim]) formed by - summing the outer products of all the vectors in x of size - `x.shape[dim]`, that we get by indexing x with all tuples of dimensions - on other axes. E.g. if x is a matrix and dim == 0, this would - be torch.matmul(x, x.transpose(0, 1)). - - Note: x must have at least 2 dimensions, x.ndim >= 2. - """ - x = _view_as_matrix(x, dim) - return torch.matmul(x.transpose(0, 1), x) - - -def _multiply_on_dim(x: Tensor, m: Tensor, dim: int) -> Tensor: - """ - Multiplies x by the matrix m which must be of shape: - (x.shape[dim], n)), with `dim` as the dimension/axis on - x to be multiplied. - - Caution: result may not have the same layout/strides as x, - although it will have the same shape. - - Args: - x: Tensor to be multiplied; must have ndim >= 2 - m: Symmetric matrix to multiply x by; must have - m.shape == (x.shape[dim], x.shape[dim]) - dim: Dimension of x to multiply on, with 0 <= dim < x.ndim - Return: - The matrix product, of the same shape as - x, except with the size on dimension `dim` being n. - """ - ndim = x.ndim - if dim != ndim - 1: - x = x.transpose(dim, ndim - 1) - ans = torch.matmul(x, m) - if dim != ndim - 1: - # Swap the dimensions back to what they were originally. - ans = ans.transpose(dim, ndim - 1) - return ans - - -def _multiply_product_combined( - l2: Tensor, grad: Tensor, dim: int, need_grad_sumsq: bool -): - """ - This function is an optimized version of the following code: - outer_prod = _outer_product(grad, dim) - l2 = _multiply_on_dim(l2, outer_prod, dim) - if dim == 0: # could choose any dim for this - grad_sumsq = torch.trace(outer_prod) - Args: - l2: The l2 matrix which starts out as the parameter tensor x, must have >= 2 diims - grad: The gradient tensor (or a gradient-like quantity); must - have same shape as l2. - dim: The dimension of l2 and grad that we want this to - act on, with 0 <= dim < l2.ndim. We multiply l2, on - this dim, by a symmetric quantity of shape - (l2.shape[dim], l2.shape[dim]), that is formed - by a product and sum on grad (this is a matrix - product, if there are 2 axes). - Returns: - Returns (l2, grad_sumsq), where l2 is the result of - multiplying l2 by the product mentioned above, and - grad_sumsq is either None, or a Tensor representing - the sum-of-squares of `grad`; for at least one - dim with 0 <= dim < l2.ndim, we guarantee to - return such a Tensor. - """ - grad = _view_as_matrix(grad, dim) - if grad.shape[1] <= grad.shape[0]: - # Minimize the size of the intermediate product, which will probably well reflect - # the compute time since memory access can be limiting on CUDA.a - grad_product = torch.matmul(grad.transpose(0, 1), grad) - l2 = _multiply_on_dim(l2, grad_product, dim) - if need_grad_sumsq: - grad_sumsq = torch.trace(grad_product) - else: - grad_sumsq = None - return (l2, grad_sumsq) - else: - l2 = _multiply_on_dim(l2, grad.transpose(0, 1), dim) - l2 = _multiply_on_dim(l2, grad, dim) - # This branch does not compute grad_sumsq, but we're bound to - # take the other branch on at least one occasion. - return (l2, None) - - -def _compute_l2_grad( - x: Tensor, - grad: Tensor, - target_stddev: float, - rho: float, - grad_scale: float = 1.0, - period_scale: int = 1, - eps: float = 1.0e-08, - safe: bool = True, -) -> Tensor: - """ - Returns the l2 gradient of x, which will be added to 'grad'. - This is a more principled replacement for the typical l2 regularization - formula where we do: - grad += weight_decay * x. - (Note: this must only be called if x.ndim >= 2). - - For x with 2 axes, we instead do this: - - grad += (rho / (2*target_stddev**2)) * (grad grad^T) x (grad^T grad) / trace(grad^T grad), - - where the implicit multiplication above refers to matrix multiplication; note, x means - the variable x. We'll have to write the justification of this, which is a little - complicated, separately; it has to do with using exactly the amount of l2 in each - subspace of each dimension of x, to to cancel out the gradient noise. - - Args: - x: parameter to be updated. MUST HAVE x.ndim >= 2. - grad: Gradient for x on this iteration (or at least, something that - is treated like a gradient in the update formula) - target_stddev: The target standard deviation (uncentered), of elements of x. - This is our estimate of what standard deviation these elements would - have in a well-trained model; it is set by some kind of heuristic. - rho: The learning rate we are going to use, as in: x -= (grad + l2) * rho. - grad_scale: A scale whereby the caller asserts that `grad` is some - quantity that is distributed like the real - gradient times `grad_scale` (this is useful when the provided `grad` - is really a moving average gradient). Because the l2 term's magnitude - is proportional to the gradient squared, we need to divide it by the - square of grad_scale, so this function uses 1/grad_scale^2 as a scaling - factor. - period_scale: An integer scale that we use to compensate for the fact that this - weight decay is only applied periodically, once every - `period_scale` minibatches. Accordingly, we make the l2 term - that many times larger. - eps: A small constant used to avoid division by zero - safe: If true, use a safe version of the formula that checks for - 'overshoot' of l2 regularization and fixes the issue (might - be an issue for models that are getting unstable or have high - learning rate) - - - Returns: - Returns l2 pseudo-gradient (term to be added to `grad`). - """ - assert x.shape == grad.shape - assert x.ndim >= 2 - - l2 = x - grad_sumsq = None - num_ignored_dims = 0 # for an optimization for when size=1 on some dim. - for dim in range(x.ndim): - # The code below is an optimization of the following few lines, - # which were perhaps easier to understand: - # outer_prod = _outer_product(grad, dim) - # l2 = _multiply_on_dim(l2, outer_prod, dim) - # if dim == 0: # could choose any dim for this - # grad_sumsq = torch.trace(outer_prod) - if x.shape[dim] <= 1: - num_ignored_dims += 1 - continue - (l2, maybe_grad_sumsq) = _multiply_product_combined( - l2, grad, dim, grad_sumsq is None - ) - if maybe_grad_sumsq is not None: - grad_sumsq = maybe_grad_sumsq - if grad_sumsq is None: - # We shouldn't reach here, except if at some point we start calling this - # code for tensors with ndim <= 1, or with numel() == 1. - grad_sumsq = (grad ** 2).sum() - - # l2 is the amount of l2, we'll subtract this from x, as in: - # x -= rho * (grad + l2). - - factor = rho * period_scale / (2.0 * (target_stddev * grad_scale) ** 2) - l2 = l2 * (factor / (grad_sumsq ** (x.ndim - 1 - num_ignored_dims) + eps)) - - if safe and rho > 0: - # x2_sum = (x ** 2).sum() - l2_sum = (l2 ** 2).sum() * (rho * rho) - cross_sum = (x * l2).sum() * rho - alpha = cross_sum / (l2_sum + eps) - # We want to minimize the sum-of-squares of (x - alpha * rho * l2), where alpha - # is a constant in [0,1] that we are about to estimate, intended to prevent - # instability by scaling down our weight decay formula. Right now (and treating - # things as if they were scalars for brevity): - # x2_sum = x * x - # l2_sum = rho * rho * l2 * l2 - # cross_sum = x * rho * l2 - # We want to minimize the sum-sq of (x - alpha * rho * l2), - # i.e. we want to choose alpha to minimize: - # x2_sum - 2 * alpha * cross_sum + alpha^2 * l2_sum - # d/dalpha of this, is: - # -2*cross_sum + 2 * alpha * l2_sum - # and setting this to zero and solving for alpha, we have: - # alpha = cross_sum / l2_sum. - # If it turns out that alpha >= 1, then we just use alpha=1 - # (the original formula), as there is no problem with - # instability/overshoot. - l2.mul_(alpha.clamp(max=1.0)) - if random.random() < 0.001 and alpha < 1.0: - logging.info( - f"madam optimizer: alpha={alpha}, shape={tuple(x.shape)}" - ) - return l2 - - -class Moam(object): - """ - Implements Moam optimizer. This is a modified version of the Noam optimizer - which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, - but changed to use Madam (see above) instead of Adam as the base optimizer. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py - - Caution: you probably want to set 'factor' to a smaller value than you would typically - use for a corresponding Noam optimizer, because Moam does a kind of l2 regularization which - keeps the parameters fairly small, so the relative changes in model parameters - will be larger than Noam, for any given learning rate. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups - model_size: attention dimension of the transformer model - factor: learning rate factor, that multiplies the output of the - formula based on model size - warm_step: number of warmup steps before the learning rate starts to decrease - (it increases until this point). - min_target_rms: this is a parameter of the Madam optimizer; it represents a floor - on the "target root-mean-square value" that is used when the initialization - of a tensor is zero or below this value. It may be worth optimizing. - Don't worry about tensors with fewer than 2 dimensions when setting this, - these are not subject to our l2 formula. - limit_grad_factor: you can set this to a finite value, e.g. 2.0, to activate - a mechanism that limits the norms of larger-than-usual gradients. - This seems to cause a slowdown, likely due to GPU->CPU transfers. - l2_period: mechanism to improve the optimization speed, by only applying the l2 - regularization (which is a complicated formula) every this-many - minibatches. E.g. can set it to 2 or 4. - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 2.0, - # warm_step: int = 25000, - min_target_rms: float = 0.05, - limit_grad_factor: float = float("inf"), - l2_period: int = 1, - ) -> None: - """Construct an Noam object.""" - self.optimizer = Madam( - params, - lr=0, - betas=(0.9, 0.98), - eps=1e-9, - min_target_rms=min_target_rms, - limit_grad_factor=limit_grad_factor, - l2_period=l2_period, - ) - self._step = 0 - # self.warmup = warm_step - self.factor = factor - self.model_size = model_size - # self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - # rate = self.rate() - # for p in self.optimizer.param_groups: - # p["lr"] = rate - # self._rate = rate - self.optimizer.step() - - # def rate(self, step=None): - # """Implement `lrate` above.""" - # if step is None: - # step = self._step - # return ( - # self.factor - # * self.model_size ** (-0.5) - # * min(step ** (-0.5), step * self.warmup ** (-1.5)) - # ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - # "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - # "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -class TestModel(torch.nn.Module): - """Class for testing the Madam optimizer""" - - def __init__(self): - super(TestModel, self).__init__() - self.first_layers = torch.nn.Sequential( - torch.nn.Linear(100, 200), - torch.nn.ReLU(), - torch.nn.Linear(200, 300), - torch.nn.ReLU(), - ) - self.conv1 = torch.nn.Conv1d( - in_channels=300, out_channels=200, kernel_size=1 - ) - self.relu = torch.nn.ReLU() - self.conv2 = torch.nn.Conv1d( - in_channels=200, out_channels=250, kernel_size=3 - ) - - def forward(self, x): - # from (B, T, 100) to (B, T, 200) - x = self.first_layers(x) - # B, T, C -> B, C, T - x = x.transpose(1, 2) - x = self.conv2(self.relu(self.conv1(x))) - # B, C, T -> B, T, C - x = x.transpose(1, 2) - return x - - -def test_madam(): - print("Testing Madam optimizer") - global inf_grad_max_count - inf_grad_max_count = 200 - if torch.cuda.is_available(): - devices_and_l2 = [ - (torch.device("cuda"), True), - (torch.device("cuda"), False), - (torch.device("cpu"), True), - (torch.device("cpu"), False), - ] - else: - devices_and_l2 = [ - (torch.device("cpu"), True), - (torch.device("cpu"), False), - ] - - for (device, l2) in devices_and_l2: - model = TestModel().to(device) - # min_target_rms=0.01 is for testing, so the target equals the initial RMS - # and we can more easily tell whether our update has the desired effect. - # I also tested this with betas=(0.1, 0.98), to check that the effect of - # `grad_scale` was correct (it only makes much difference for small beta). - optimizer = Madam( - model.parameters(), - lr=0.0005, - betas=(0.9, 0.98), - l2=l2, - min_target_rms=0.01, - l2_period=1, - ) - # optimizer = torch.optim.Adam(model.parameters()) - - def get_elems_rms(x: Tensor) -> Tensor: - return ((x ** 2).sum() / x.numel()).sqrt().item() - - for i in range(1000): - if i % 100 == 0: - rms_values = ( - get_elems_rms(model.first_layers[0].weight), - get_elems_rms(model.first_layers[2].weight), - get_elems_rms(model.conv1.weight), - get_elems_rms(model.conv2.weight), - ) - print( - f"Iter {i}, l2={l2}, device={device}: stddevs = {rms_values} " - ) - B = 4 - T = 20 - x = torch.randn(B, T, 100).to(device) - y = model(x) - yderiv = torch.randn_like(y) - if i % 190 <= 3 and i > 0: - yderiv *= 100.0 - if i % 550 == 0 and i > 0: - yderiv *= float("inf") - - y.backward(gradient=yderiv) - optimizer.step() - model.zero_grad() - print("") - - -def test_moam(): - print("Testing Moam optimizer") - model = TestModel() - # min_target_rms=0.01 is for testing, so the target equals the initial RMS - # and we can more easily tell whether our update has the desired effect. - optimizer = Moam( - model.parameters(), factor=1.0, warm_step=300, min_target_rms=0.01 - ) - - def get_elems_rms(x: Tensor) -> Tensor: - return ((x ** 2).sum() / x.numel()).sqrt().item() - - for i in range(1000): - if i % 100 == 0: - rms_values = ( - get_elems_rms(model.first_layers[0].weight), - get_elems_rms(model.first_layers[2].weight), - get_elems_rms(model.conv1.weight), - get_elems_rms(model.conv2.weight), - ) - print(f"Iter {i} (Moam): stddevs = {rms_values} ") - B = 4 - T = 20 - x = torch.randn(B, T, 100) - y = model(x) - yderiv = torch.randn_like(y) - if i % 190 <= 3 and i > 0: - yderiv *= 100.0 - if i % 550 == 0 and i > 0: - yderiv *= float("inf") - - y.backward(gradient=yderiv) - optimizer.step() - model.zero_grad() - print("") - - -def test_to_device(): - if not torch.cuda.is_available(): - return - a_gpu = torch.ones(1, 2, 3, 4, device="cuda") - b_gpu = torch.zeros(3, 8, device="cuda") - (a_cpu, b_cpu) = _to_device("cpu", a_gpu, b_gpu) - print("a_cpu,b_cpu = ", a_cpu, b_cpu) - (a_gpu2, b_gpu2) = _to_device("cuda", a_cpu, b_cpu) - print("a_gpu2,b_gpu2 = ", a_gpu2, b_gpu2) - - -# Caution: this testing code is not very automated, it reqires looking at the output to -# make sure it looks right. The main thing is that with l2=True, the printed stddevs stay close -# to the "Target rms" values, which are printed out; while with l2=False, the stddevs -# increase to significantly higher than that. -# -# The test of the Moam optimizer is mainly to make sure it runs; the scale of the -# gradients, and the learning rate, are such that one of the rms's stays quite a bit -# above the target value, i.e. (0.047, 0.044, 0.047), vs. targets of -# (0.057, 0.04, 0.019), I think this has to do with the alpha<1 stability mechanism being -# activated, the l2 does have an effect, as I verified by changing the code to set -# l2=False. -def main(): - # Set number of threads to 1, or Torch can do weird things that make it extremely slow. - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - # test_to_device() - random.seed(0) - torch.random.manual_seed(0) - test_madam() - test_moam() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py index 24e3357d3..fb74b5781 100755 --- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py +++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py @@ -17,8 +17,7 @@ import torch.nn as nn from conformer import Conformer from lhotse.utils import fix_random_seed -# from transformer import Noam -from madam_no_warmup import Moam +from madam import Madam from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -37,6 +36,29 @@ from icefall.utils import ( ) +def create_madam( + params, + model_size: int = 256, + factor: float = 2.0, + warm_step: int = 25000, + min_target_rms: float = 0.05, + limit_grad_factor: float = float("inf"), + l2_period: int = 1, +): + initial_lr = warm_step ** (-0.5) + optimizer = Madam( + params, + lr=initial_lr, + betas=(0.9, 0.98), + eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period, + ) + + return optimizer + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -86,8 +108,6 @@ def get_params() -> AttributeDict: - lang_dir: It contains language related input files such as "lexicon.txt" - - lr: It specifies the initial learning rate - - feature_dim: The model input dim. It has to match the one used in computing features. @@ -155,6 +175,7 @@ def get_params() -> AttributeDict: "mmi_loss": False, "use_feat_batchnorm": False, "lr_factor": 2.0, + "warm_step": 30000, } ) @@ -697,14 +718,15 @@ def run(rank, world_size, args): if world_size > 1: model = DDP(model, device_ids=[rank]) - optimizer = Moam( + optimizer = create_madam( model.parameters(), model_size=params.attention_dim, factor=params.lr_factor, + warm_step=params.warm_step, ) scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2) + optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2) ) if checkpoints and checkpoints["optimizer"]: @@ -720,8 +742,6 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - # LR scheduler can hold multiple learning rates for multiple parameter groups; - # For now we report just the first LR which we assume concerns most of the parameters. cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: tb_writer.add_scalar(