From c3a87274465953971fc00c03e46205f9cec490f2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 22:28:45 +0800 Subject: [PATCH] Add train.py --- egs/librispeech/ASR/conformer_lm/dataset.py | 5 +- egs/librispeech/ASR/conformer_lm/madam.py | 959 ++++++++++++++++++++ egs/librispeech/ASR/conformer_lm/train.py | 607 +++++++++++++ 3 files changed, 1569 insertions(+), 2 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_lm/madam.py create mode 100755 egs/librispeech/ASR/conformer_lm/train.py diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index cad2c4d8f..3074d1099 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -3,7 +3,8 @@ import torch.distributed as dist import k2 import _k2 import sentencepiece as spm -from typing import Optional, List, Tuple +from pathlib import Path +from typing import Optional, List, Tuple, Union @@ -36,7 +37,7 @@ class LmDataset(torch.utils.data.Dataset): return k2.index(self.words, sentence).values().tolist() -def load_train_test_lm_dataset(archive_fn: str, +def load_train_test_lm_dataset(archive_fn: Union[str,Path], test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]: """ returns (train_lm_dataset, test_lm_dataset) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py new file mode 100644 index 000000000..aa605c30b --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -0,0 +1,959 @@ +import logging +import math +import random +import torch +from torch import nn +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing import List, Tuple + + + +# After this many warnings about infinite gradients we'll die. +inf_grad_count = 0 +inf_grad_max_count = 20 + +class Madam(Optimizer): + r"""Madam is a modification of the Adam algorithm, with various changes + intended to support certain "common-sense" ideas and solve common + pathologies that can happen particularly in transformer-type models that + have multiplication of parameters (particularly, key and query matrices)-- + these can be vulnerable to "subspace loss" where, if you have any l2 + regularization, certain subspaces in the key/query space might get + regularized towards zero. We solve this with a special formula that + changes how the l2/weight-decay is done (see compute_l2_grad()). + I'll try to write the math down at some point. This formula only + applies to tensors that have at least two dimensions; for one-dimensional + tensors we simply won't do l2 regularization. + + One more thing-- there is a special pathology that can sometimes afflict + models like LSTMs, where a particular element of a minibatch experiences + gradient blowup in the backward pass. We'd like to identify such cases and + fix it somehow, e.g. by removing or scaling down the gradient for that + particular minibatch. We can identify and somewhat fix this by seeing that the + gradient norm (computed over all the parameters in a parameter group) is + much more than on previous minibatches, and limiting it to (the preceding + average step size times some constant). + + Like most optimization algorithms, for this to work well you need to + have an appropriate learning rate schedule, either decreasing with + time, or increasing (warm-up) and then decreasing. The LR schedule may + possibly need to decrease a little more aggressively than you would with + Adam, or at least have smaller values overall than Adam, because + the smaller parameters will mean the effective (relative) learning + rate is higher. + + This is modified from PyTorch's optim/adam.py + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + grad_norm_buffer_size (int, optional): Buffer size used in detecting + minibatches with unusually large gradients and scaling them down. + limit_grad_factor (float): factor by which we don't allow the + gradient to be greater than the average of previous gradients + (we'll scale the gradient down, over the whole param-group, + to enforce this). Must be greater than 1. Set to float('inf') + to disable norm clipping. + min_target_rms: A floor on the "target rms" of each Tensor, so + that Tensors that, when initialized, have less than this + rms value will have their target rms value floored to this + l2: True to enable l2 regularization + l2_period: You may set this to a value greater than one to save + computation by only periodically doing the l2 update. + We include a scaling factor in the formula so that, as far + as possible (for small learning rates) this shouldn't affect + the results. (Note: this probably isn't necessary to set, + since it turns out the update is quite fast, at least on GPU, + and the gradient clipping is actually more of a problem) + + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ + + def __init__(self, params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + grad_norm_buffer_size: int = 8, + limit_grad_factor: float = 2.0, + min_target_rms: float = 0.05, + l2: bool = True, + l2_period: int = 1): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not (isinstance(grad_norm_buffer_size, int) and grad_norm_buffer_size > 1): + raise ValueError("Invalid grad_norm_buffer_size value: {}".format(grad_norm_buffer_size)) + if not limit_grad_factor > 1.0: + raise ValueError("Invalid limit_grad_factor: {}".format(limit_grad_factor)) + if not isinstance(l2, bool): + raise ValueError("Invalid l2 value: {}".format(l2)) + if not l2_period >= 1: + raise ValueError("Invalid l2_period value: {}".format(l2_period)) + defaults = dict(lr=lr, betas=betas, eps=eps, + grad_norm_buffer_size=grad_norm_buffer_size, + limit_grad_factor=limit_grad_factor, + l2=l2, l2_period=l2_period, + min_target_rms=min_target_rms) + super(Madam, self).__init__(params, defaults) + + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + beta1, beta2 = group['betas'] + grad_norm_buffer_size = group['grad_norm_buffer_size'] + limit_grad_factor = group['limit_grad_factor'] + min_target_rms = group['min_target_rms'] + + # The next 5 lists are part of the original Adam optimizer + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + + # The next 3 lists are not part of the original Adam optimizer. + target_rms_values = [] # relates to weight decay. Target root-mean-square + # values of the elements of each parameter + # we are optimizing + prev_norm_stats = [] # contains Tensor with 2 elements each, the sum + # of the [sum_squared, count] of + # this parameter on previous minibatches (up to + # grad_norm_buffer_size minibatches) + cur_grad_norms = [] # and `cur_grad_norms` contains the squared l2 + # norm norm of this step's gradient for this + # parameter, as a Tensor. + + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + # The things below are not part of original Adam, they are the Madam extension.. + state['target_rms'] = _get_target_rms(p, min_target_rms) + # grad_norm_buf is a rotating buffer containing (grad_norm**2, count), where + # count is 1 for grad_norms that are set and 0 for those that are not set because + # we're near step 0 or because they were infinite. + state['grad_norm_buf'] = torch.zeros(grad_norm_buffer_size, 2, device=p.device) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + target_rms_values.append(state['target_rms']) + + cur_step = state['step'] + if limit_grad_factor != float('inf'): + grad_norm_buf = state['grad_norm_buf'] + cur_grad_norm = (p.grad ** 2).sum() # actually squared nom + prev_mean_norm = grad_norm_buf.sum(0) # prev_mean_norm is a Tensor [ tot_norm_squared, count ] + grad_norm_buf[cur_step % grad_norm_buffer_size][0] = cur_grad_norm + grad_norm_buf[cur_step % grad_norm_buffer_size][1].fill_(1.0) + prev_norm_stats.append(prev_mean_norm) + cur_grad_norms.append(cur_grad_norm) + + # update the steps for each param group update + cur_step += 1 + state['step'] = cur_step + # record the step after step update + state_steps.append(cur_step) + + if limit_grad_factor != float('inf'): + self._apply_grad_norm_clipping(group['params'], + prev_norm_stats, cur_grad_norms, grads, + limit_grad_factor, grad_norm_buffer_size) + + _madam(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + target_rms_values, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + eps=group['eps'], + l2=group['l2'], + l2_period=group['l2_period']) + + return loss + + + def _apply_grad_norm_clipping(self, + params_list, + prev_norm_stats: List[Tensor], + cur_grad_norms: List[Tensor], + grads: List[Tensor], + limit_grad_factor: float, + grad_norm_buffer_size: int) -> None: + """ + This function applies gradient norm clipping for this parameter group if this + minibatch has substantially larger gradients in this param group than + recent minibatches. The idea is to catch cases like where an LSTM + happens to blow up in the backward pass, or some code bug causes very + large or infinite gradients on a particular minibatch; so we scale + down any very large gradients and zero infinite ones. + + Args: + params_list: some kind of iterable or list of params in this group + prev_norm_stats: a list which, for each parameter in this group + with a grad, contains a Tensor with 2 elements each, containing + # the [sum, count] of up to `grad_norm_buffer_size` + # norms of this parameter on previous minibatches; + cur_grad_norms: a list of Tensor containing, for each parameter in this group, + the norm of this step's gradient for this parameter. + grads: List of gradients with the same order as prev_norm_stats and + cur_grad_norms + limit_grad_factor: a float >1.0 (e.g. 4.0) that dictates + how-much-larger-than-average gradients we allow before clipping. + grad_norm_buffer_size: an int that determines the rolling buffer size over which + we store gradient norms + """ + num_params = len(prev_norm_stats) + assert len(grads) == num_params + + all_prev_norm_stats, all_cur_grad_norms = _to_device('cpu', + torch.stack(prev_norm_stats), + torch.stack(cur_grad_norms)) + assert all_prev_norm_stats.shape == (num_params, 2) + assert all_cur_grad_norms.shape == (num_params,) + + # divide totals by counts (i.e. counts of iterations were we stored + # a finite grad) + all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1] + # prev_norm and cur_norm are floats, they are actually squared norms. + prev_norm = all_prev_grad_norms.sum().item() + cur_norm = all_cur_grad_norms.sum().item() + + if prev_norm - prev_norm != 0.0: + # There were zero counts; fix this by using the current grad norm + # for affected parameters, and recompute all_prev_grad_norms and + # prev_norm. + for i in range(num_params): + if all_prev_norm_stats[i][1] == 0.0: + # if count is 0 and cur norm is finite, use cur norm as our estimate + # of previous norms. This would only be useful if some but not + # all params were in this situation of having no previous estimates. + cur = all_cur_grad_norms[i] + if cur - cur == 0.0: # finite.. + all_prev_norm_stats[i][0] = cur + all_prev_norm_stats[i][1] = 1.0 + else: + # 0.0 is a default; likely won't matter, as if we + # get infinite `cur`, we'll abandon this minibatch. + all_prev_norm_stats[i][0] = 0.0 + all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1] + prev_norm = all_prev_grad_norms.sum().item() + + # Deal with infinite gradients. + if cur_norm - cur_norm != 0: # cur_norm is infinite or NaN + global inf_grad_count + logging.warning(f'Infinite gradient-norm detected (cur/prev: {cur_norm}/{prev_norm}): will ' + f'zero grad ({inf_grad_count}/{inf_grad_max_count} times until dying)') + inf_grad_count += 1 + if inf_grad_count >= inf_grad_max_count: + assert 0, "Reached max count of infinite gradient-norm stats" + # Zero all gradients in this group + for g in grads: + g[:] = 0. + # .. and zero the stored gradient norms in grad_norm_buf (so + # that infinities don't ruin our stats of previous batches) + for p in params_list: + if p.grad is not None: + state = self.state[p] + grad_norm_buf = state['grad_norm_buf'] + # cur_step is the location where we would have written the grad_norm. + # We didn't check if it was infinity before, because we didn't want to + # incur lots of GPU->CPU transfers. + cur_step = state['step'] - 1 + # Remove this 'bad' step from the buffer. + grad_norm_buf[cur_step % grad_norm_buffer_size][:] = 0.0 + else: + # cur_norm is finite. Check whether we have to clip this iteration's grad. + # we always remove infinities/NaNs from the buffer, so prev_norm should not + # be infinite or NaN. + assert prev_norm - prev_norm == 0.0 + # cur_norm and prev_norm are actually squared norms, so we need to + # square limit_grad_factor.. + limit_grad_factor2 = limit_grad_factor ** 2 + if cur_norm > prev_norm * limit_grad_factor2: + grad_factor2 = (prev_norm * limit_grad_factor2) / cur_norm + grad_factor = grad_factor2 ** 0.5 + cur_norm_f, prev_norm_f, grad_factor_f = ('%.2g' % cur_norm, '%.2g' % prev_norm, + '%.2g' % grad_factor) + logging.warning(f'Gradient norm exceeds average of last {grad_norm_buffer_size} ' + f'gradients times {limit_grad_factor}: cur/prev {cur_norm_f}/{prev_norm_f}: ' + f'scaling it by {grad_factor_f}.') + for g in grads: + g[:] *= grad_factor + # .. and scale down the stored gradient norms in grad_norm_buf, to + # avoid the bound getting too loose too quickly. + for p in params_list: + if p.grad is not None: + state = self.state[p] + grad_norm_buf = state['grad_norm_buf'] + cur_step = state['step'] - 1 + # the buffer contains squared norms, so multiply by grad_factor2 + grad_norm_buf[cur_step % grad_norm_buffer_size][0] *= grad_factor2 + + +def _to_device(device, *args): + """ + Transfers a tuple of Tensors from one device to another, using a single transfer. Must have + same dtype but may have different shapes. + E.g. + (cpu_tensor_a, cpu_tensor_b) = _to_device('cpu', gpu_tensor_a, gpu_tensor_b) + """ + if device == args[0].device: + return args + else: + arg0 = args[0] + combined_src = torch.cat([ x.reshape(-1) for x in args ]) + combined_dest = combined_src.to(device) + dests = [] + offset = 0 + for src in args: + numels = src.numel() + dests.append(combined_dest[offset:offset+numels].reshape(src.shape)) + offset += numels + return tuple(dests) + + + +def _get_target_rms(x: Tensor, min_target_rms: float) -> Tensor: + """ + Returns Tensor with one element, representing a target root-mean-square + value of elements of x, that we consider "reasonable", and will use a + as a "target rms" in our modified weight-decay formula. It returns + the maximum of the current RMS of the values of x, and `min_target_rms`, + as a Tensor on the same device as x. + """ + with torch.no_grad(): + # `norm` is the 2-norm of x currently (and this function should be + # called right after parameter initialization) + rms = ((x ** 2).sum() / x.numel()).sqrt() + largest_dim = max(list(x.shape)) + numel = x.numel() + if min_target_rms > 0.0: + rms = rms.clamp(min=min_target_rms) + if x.ndim > 1 and __name__ == '__main__': # will only be used for x.ndim > 1. + print("Target rms = ", rms) # Print this in testing only. + return rms + + +def _madam(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[int], + target_rms_values: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + eps: float, + l2: bool, + l2_period: int): + r"""This is a modification of adam() from torch's optim/_functional.py. + + It has been modified to: + (i) remove the amsgrad option; this shouldn't be as necessary due to + the adaptive gradient norm clipping we have added + (ii) add our special formula for l2 regularization. This doesn't have + any tunable parameters, other than the target standard deviation + of the elements of the tensor (which is passed in as target_rms). + Args: + params: list of Tensor, containing the parameters to be optimized + grads: list of Tensor, containing the gradients corresponding to + each of the params (grads[i] should correspond to params[i].grad, + although it may have undergone gradient clipping). + exp_avgs: list of Tensor, containing tensors with the same dimensions + as params and grads, that contain the moving-averages of + `grads`. + exp_avg_sqs: list of Tensor, containing tensors with the same dimensions + as params and grads, that contain the moving-averages of + `grads ** 2`. + state_steps: list of int, containing the step for each parameter (step >= 1) + target_rms_values: list of Tensor with one element each, containing the + target root-mean-square values of each parameter tensor in `params` + l2: a bool, where if true we will activate the l2 regularization + formula. + l2_period: an integer that determines how often (i.e. every how many + minibatches) we apply the l2 update. We include a scaling factor + so that as far as possible the result will not be too sensitive + to the value of this. + + beta1: decay factor for gradients, e.g. 0.9 + beta2: decay factor for gradients squared, e.g. 0.999 + lr: learning rate, e.g. 0.0001 + eps: a small constant used to prevent division by zero, e.g. 1.0e-8 + + See :class:`~torch.optim.Adam` for details. + """ + assert len(params) == len(grads) == len(state_steps) == len(exp_avgs) == len(exp_avg_sqs) == len(target_rms_values) + + for i, param in enumerate(params): + + grad = grads[i] + + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + target_rms = target_rms_values[i] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + do_l2 = param.ndim > 1 and l2 and step % l2_period == 0 + + if do_l2: + # This represents just the "noise term" of the gradient, i.e. the grad minus the + # running mean. We'll later divide by denom. + cur_grad_noise = (grad - exp_avg) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + if not do_l2: + param.addcdiv_(exp_avg, denom, value=-step_size) + else: + # We can treat "pseudo_grad" as if it were a gradient (even though it's + # actually a gradient times a per-element learning rate). The analysis + # that we used to figure out what the l2 should be did not use the fact + # that the gradients were actually gradients, it simply analyzed it as a + # quantity that can be treated as close to zero-mean and with a certain + # structure of variance, and added to the param with the formula: + # + # param -= step_size * grad + # + # The original analysis assumed the gradients were independent from frame + # to frame; in fact these are not, but the important difference can be captured + # in a scalar `grad_scale` that expresses the scale of pseudo_grad relative + # to the independent gradients that we are effectively adding on each frame + # (but with a delay). + + pseudo_grad = exp_avg / denom + cur_pseudo_grad = cur_grad_noise / denom + + # grad_scale expresses the expected size of cur_pseudo_grad relative to the + # original grads if we had not done the moving-average; it is the sqrt of + # the sum of the squares of coefficients of previous gradients: + # c_n = (1-beta1) beta1^n, for + # n = 0, 1, .. + # .. plus one which is the sumsq of the coefficient of 'grad' itself in + # (grad - exp_avg). + # It is relevant that the sum of the coefficients (i.e. not squared) is 1; + # if this were not so we'd have to incorporate that into the formula for l2. + grad_scale = (((1 - beta1)**2) / (1 - beta1**2) + 1) ** 0.5 + + with torch.no_grad(): + l2_grad = _compute_l2_grad(param, cur_pseudo_grad, target_rms, + rho=step_size, grad_scale=grad_scale, + period_scale=l2_period, + eps=eps, safe=True) + + # TODO: could alternate computing l2 on only, say, odd frames, and scale it + # up by 2, to save time. + param.add_(pseudo_grad + l2_grad, alpha=-step_size) + + + +def _view_as_matrix(x: Tensor, dim: int) -> Tensor: + """ + Returns a Tensor of shape (n, x.shape[dim]), where n is the product + of the sizes of the other dimensions of x. This may involve a copy, + if x cannot be reshaped in this way. + """ + ndim = x.ndim + assert ndim > 1 and dim >= 0 and dim < ndim + # Move the dim to the last position in x.. + if dim != ndim - 1: + x = x.transpose(dim, ndim - 1) + return x.reshape(-1, x.shape[-1]) + + +def _outer_product(x: Tensor, dim: int) -> Tensor: + """ + Returns a Tensor of shape (x.shape[dim], x.shape[dim]) formed by + summing the outer products of all the vectors in x of size + `x.shape[dim]`, that we get by indexing x with all tuples of dimensions + on other axes. E.g. if x is a matrix and dim == 0, this would + be torch.matmul(x, x.transpose(0, 1)). + + Note: x must have at least 2 dimensions, x.ndim >= 2. + """ + x = _view_as_matrix(x, dim) + return torch.matmul(x.transpose(0, 1), x) + +def _multiply_on_dim(x: Tensor, m: Tensor, dim: int) -> Tensor: + """ + Multiplies x by the matrix m which must be of shape: + (x.shape[dim], n)), with `dim` as the dimension/axis on + x to be multiplied. + + Caution: result may not have the same layout/strides as x, + although it will have the same shape. + + Args: + x: Tensor to be multiplied; must have ndim >= 2 + m: Symmetric matrix to multiply x by; must have + m.shape == (x.shape[dim], x.shape[dim]) + dim: Dimension of x to multiply on, with 0 <= dim < x.ndim + Return: + The matrix product, of the same shape as + x, except with the size on dimension `dim` being n. + """ + ndim = x.ndim + if dim != ndim - 1: + x = x.transpose(dim, ndim - 1) + ans = torch.matmul(x, m) + if dim != ndim - 1: + # Swap the dimensions back to what they were originally. + ans = ans.transpose(dim, ndim - 1) + return ans + + +def _multiply_product_combined(l2: Tensor, grad: Tensor, dim: int, + need_grad_sumsq: bool): + """ + This function is an optimized version of the following code: + outer_prod = _outer_product(grad, dim) + l2 = _multiply_on_dim(l2, outer_prod, dim) + if dim == 0: # could choose any dim for this + grad_sumsq = torch.trace(outer_prod) + Args: + l2: The l2 matrix which starts out as the parameter tensor x, must have >= 2 diims + grad: The gradient tensor (or a gradient-like quantity); must + have same shape as l2. + dim: The dimension of l2 and grad that we want this to + act on, with 0 <= dim < l2.ndim. We multiply l2, on + this dim, by a symmetric quantity of shape + (l2.shape[dim], l2.shape[dim]), that is formed + by a product and sum on grad (this is a matrix + product, if there are 2 axes). + Returns: + Returns (l2, grad_sumsq), where l2 is the result of + multiplying l2 by the product mentioned above, and + grad_sumsq is either None, or a Tensor representing + the sum-of-squares of `grad`; for at least one + dim with 0 <= dim < l2.ndim, we guarantee to + return such a Tensor. + """ + grad = _view_as_matrix(grad, dim) + if grad.shape[1] <= grad.shape[0]: + # Minimize the size of the intermediate product, which will probably well reflect + # the compute time since memory access can be limiting on CUDA.a + grad_product = torch.matmul(grad.transpose(0, 1), grad) + l2 = _multiply_on_dim(l2, grad_product, dim) + if need_grad_sumsq: + grad_sumsq = torch.trace(grad_product) + else: + grad_sumsq = None + return (l2, grad_sumsq) + else: + l2 = _multiply_on_dim(l2, grad.transpose(0, 1), dim) + l2 = _multiply_on_dim(l2, grad, dim) + # This branch does not compute grad_sumsq, but we're bound to + # take the other branch on at least one occasion. + return (l2, None) + + + +def _compute_l2_grad(x: Tensor, grad: Tensor, target_stddev: float, rho: float, + grad_scale: float = 1.0, period_scale: int = 1, + eps: float = 1.0e-08, + safe: bool = True) -> Tensor: + """ + Returns the l2 gradient of x, which will be added to 'grad'. + This is a more principled replacement for the typical l2 regularization + formula where we do: + grad += weight_decay * x. + (Note: this must only be called if x.ndim >= 2). + + For x with 2 axes, we instead do this: + + grad += (rho / (2*target_stddev**2)) * (grad grad^T) x (grad^T grad) / trace(grad^T grad), + + where the implicit multiplication above refers to matrix multiplication; note, x means + the variable x. We'll have to write the justification of this, which is a little + complicated, separately; it has to do with using exactly the amount of l2 in each + subspace of each dimension of x, to to cancel out the gradient noise. + + Args: + x: parameter to be updated. MUST HAVE x.ndim >= 2. + grad: Gradient for x on this iteration (or at least, something that + is treated like a gradient in the update formula) +target_stddev: The target standard deviation (uncentered), of elements of x. + This is our estimate of what standard deviation these elements would + have in a well-trained model; it is set by some kind of heuristic. + rho: The learning rate we are going to use, as in: x -= (grad + l2) * rho. + grad_scale: A scale whereby the caller asserts that `grad` is some + quantity that is distributed like the real + gradient times `grad_scale` (this is useful when the provided `grad` + is really a moving average gradient). Because the l2 term's magnitude + is proportional to the gradient squared, we need to divide it by the + square of grad_scale, so this function uses 1/grad_scale^2 as a scaling + factor. +period_scale: An integer scale that we use to compensate for the fact that this + weight decay is only applied periodically, once every + `period_scale` minibatches. Accordingly, we make the l2 term + that many times larger. + eps: A small constant used to avoid division by zero + safe: If true, use a safe version of the formula that checks for + 'overshoot' of l2 regularization and fixes the issue (might + be an issue for models that are getting unstable or have high + learning rate) + + + Returns: + Returns l2 pseudo-gradient (term to be added to `grad`). + """ + assert x.shape == grad.shape + assert x.ndim >= 2 + + l2 = x + grad_sumsq = None + num_ignored_dims = 0 # for an optimization for when size=1 on some dim. + for dim in range(x.ndim): + # The code below is an optimization of the following few lines, + # which were perhaps easier to understand: + # outer_prod = _outer_product(grad, dim) + # l2 = _multiply_on_dim(l2, outer_prod, dim) + # if dim == 0: # could choose any dim for this + # grad_sumsq = torch.trace(outer_prod) + if x.shape[dim] <= 1: + num_ignored_dims += 1 + continue + (l2, maybe_grad_sumsq) = _multiply_product_combined(l2, grad, dim, + grad_sumsq is None) + if maybe_grad_sumsq is not None: + grad_sumsq = maybe_grad_sumsq + if grad_sumsq is None: + # We shouldn't reach here, except if at some point we start calling this + # code for tensors with ndim <= 1, or with numel() == 1. + grad_sumsq = (grad ** 2).sum() + + # l2 is the amount of l2, we'll subtract this from x, as in: + # x -= rho * (grad + l2). + + factor = rho * period_scale / (2.0 * (target_stddev * grad_scale)**2) + l2 = l2 * (factor / (grad_sumsq ** (x.ndim - 1 - num_ignored_dims) + eps)) + + if safe and rho > 0: + #x2_sum = (x ** 2).sum() + l2_sum = (l2 ** 2).sum() * (rho * rho) + cross_sum = (x * l2).sum() * rho + alpha = cross_sum / (l2_sum + eps) + # We want to minimize the sum-of-squares of (x - alpha * rho * l2), where alpha + # is a constant in [0,1] that we are about to estimate, intended to prevent + # instability by scaling down our weight decay formula. Right now (and treating + # things as if they were scalars for brevity): + # x2_sum = x * x + # l2_sum = rho * rho * l2 * l2 + # cross_sum = x * rho * l2 + # We want to minimize the sum-sq of (x - alpha * rho * l2), + # i.e. we want to choose alpha to minimize: + # x2_sum - 2 * alpha * cross_sum + alpha^2 * l2_sum + # d/dalpha of this, is: + # -2*cross_sum + 2 * alpha * l2_sum + # and setting this to zero and solving for alpha, we have: + # alpha = cross_sum / l2_sum. + # If it turns out that alpha >= 1, then we just use alpha=1 + # (the original formula), as there is no problem with + # instability/overshoot. + l2.mul_(alpha.clamp(max=1.0)) + if random.random() < 0.001 and alpha < 1.0: + logging.info(f'madam optimizer: alpha={alpha}, shape={tuple(x.shape)}') + return l2 + + + +class Moam(object): + """ + Implements Moam optimizer. This is a modified version of the Noam optimizer + which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, + but changed to use Madam (see above) instead of Adam as the base optimizer. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + + Caution: you probably want to set 'factor' to a smaller value than you would typically + use for a corresponding Noam optimizer, because Moam does a kind of l2 regularization which + keeps the parameters fairly small, so the relative changes in model parameters + will be larger than Noam, for any given learning rate. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + model_size: attention dimension of the transformer model + factor: learning rate factor, that multiplies the output of the + formula based on model size + warm_step: number of warmup steps before the learning rate starts to decrease + (it increases until this point). + min_target_rms: this is a parameter of the Madam optimizer; it represents a floor + on the "target root-mean-square value" that is used when the initialization + of a tensor is zero or below this value. It may be worth optimizing. + Don't worry about tensors with fewer than 2 dimensions when setting this, + these are not subject to our l2 formula. + limit_grad_factor: you can set this to a finite value, e.g. 2.0, to activate + a mechanism that limits the norms of larger-than-usual gradients. + This seems to cause a slowdown, likely due to GPU->CPU transfers. + l2_period: mechanism to improve the optimization speed, by only applying the l2 + regularization (which is a complicated formula) every this-many + minibatches. E.g. can set it to 2 or 4. + """ + + def __init__(self, params, model_size: int = 256, + factor: float = 2.0, warm_step: int = 25000, + min_target_rms: float = 0.05, + limit_grad_factor: float = float('inf'), + l2_period: int = 1) -> None: + """Construct an Noam object.""" + self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + + +class TestModel(torch.nn.Module): + """Class for testing the Madam optimizer""" + def __init__(self): + super(TestModel, self).__init__() + self.first_layers = torch.nn.Sequential( + torch.nn.Linear(100, 200), + torch.nn.ReLU(), + torch.nn.Linear(200, 300), + torch.nn.ReLU()) + self.conv1 = torch.nn.Conv1d(in_channels=300, out_channels=200, + kernel_size=1) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv1d(in_channels=200, out_channels=250, + kernel_size=3) + + + def forward(self, x): + # from (B, T, 100) to (B, T, 200) + x = self.first_layers(x) + # B, T, C -> B, C, T + x = x.transpose(1, 2) + x = self.conv2(self.relu(self.conv1(x))) + # B, C, T -> B, T, C + x = x.transpose(1, 2) + return x + +def test_madam(): + print("Testing Madam optimizer") + global inf_grad_max_count + inf_grad_max_count = 200 + if torch.cuda.is_available(): + devices_and_l2 = [(torch.device('cuda'), True), + (torch.device('cuda'), False)] + #(torch.device('cpu'), True), + #(torch.device('cpu'), False)] + else: + devices_and_l2 = [(torch.device('cpu'), True), + (torch.device('cpu'), False)] + + + for (device, l2) in devices_and_l2: + model = TestModel().to(device) + # min_target_rms=0.01 is for testing, so the target equals the initial RMS + # and we can more easily tell whether our update has the desired effect. + # I also tested this with betas=(0.1, 0.98), to check that the effect of + # `grad_scale` was correct (it only makes much difference for small beta). + optimizer = Madam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), + l2=l2, min_target_rms=0.01, l2_period=1) + #optimizer = torch.optim.Adam(model.parameters()) + + def get_elems_rms(x: Tensor) -> Tensor: + return ((x ** 2).sum() / x.numel()).sqrt().item() + + for i in range(1000): + if i % 100 == 0: + rms_values = (get_elems_rms(model.first_layers[0].weight), + get_elems_rms(model.first_layers[2].weight), + get_elems_rms(model.conv1.weight), + get_elems_rms(model.conv2.weight)) + print(f"Iter {i}, l2={l2}, device={device}: stddevs = {rms_values} ") + B = 4 + T = 20 + x = torch.randn(B, T, 100).to(device) + y = model(x) + yderiv = torch.randn_like(y) + if i % 190 <= 3 and i > 0: + yderiv *= 100.0 + if i % 550 == 0 and i > 0: + yderiv *= float('inf') + + y.backward(gradient=yderiv) + optimizer.step() + model.zero_grad() + print("") + +def test_moam(): + print("Testing Moam optimizer") + model = TestModel() + # min_target_rms=0.01 is for testing, so the target equals the initial RMS + # and we can more easily tell whether our update has the desired effect. + optimizer = Moam(model.parameters(), factor=1.0, warm_step=300, + min_target_rms=0.01) + + + def get_elems_rms(x: Tensor) -> Tensor: + return ((x ** 2).sum() / x.numel()).sqrt().item() + + for i in range(1000): + if i % 100 == 0: + rms_values = (get_elems_rms(model.first_layers[0].weight), + get_elems_rms(model.first_layers[2].weight), + get_elems_rms(model.conv1.weight), + get_elems_rms(model.conv2.weight)) + print(f"Iter {i} (Moam): stddevs = {rms_values} ") + B = 4 + T = 20 + x = torch.randn(B, T, 100) + y = model(x) + yderiv = torch.randn_like(y) + if i % 190 <= 3 and i > 0: + yderiv *= 100.0 + if i % 550 == 0 and i > 0: + yderiv *= float('inf') + + y.backward(gradient=yderiv) + optimizer.step() + model.zero_grad() + print("") + + + +def test_to_device(): + if not torch.cuda.is_available(): + return + a_gpu = torch.ones(1,2,3,4, device='cuda') + b_gpu = torch.zeros(3,8, device='cuda') + (a_cpu, b_cpu) = _to_device('cpu', a_gpu, b_gpu) + print("a_cpu,b_cpu = ", a_cpu, b_cpu) + (a_gpu2, b_gpu2) = _to_device('cuda', a_cpu, b_cpu) + print("a_gpu2,b_gpu2 = ", a_gpu2, b_gpu2) + +# Caution: this testing code is not very automated, it reqires looking at the output to +# make sure it looks right. The main thing is that with l2=True, the printed stddevs stay close +# to the "Target rms" values, which are printed out; while with l2=False, the stddevs +# increase to significantly higher than that. +# +# The test of the Moam optimizer is mainly to make sure it runs; the scale of the +# gradients, and the learning rate, are such that one of the rms's stays quite a bit +# above the target value, i.e. (0.047, 0.044, 0.047), vs. targets of +# (0.057, 0.04, 0.019), I think this has to do with the alpha<1 stability mechanism being +# activated, the l2 does have an effect, as I verified by changing the code to set +# l2=False. +def main(): + # Set number of threads to 1, or Torch can do weird things that make it extremely slow. + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + #test_to_device() + random.seed(0) + torch.random.manual_seed(0) + test_madam() + #test_moam() + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py new file mode 100755 index 000000000..85d671ecc --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey) +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import dataset # from . +import madam # from . +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from conformer import MaskedLmConformer +from lhotse.utils import fix_random_seed +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from madam import Moam + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist + +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - num_epochs: Number of epochs to train. + + - num_valid_batches: Number of batches of validation data to use each + time we compute validation loss + + - symbols_per_batch: Number of symbols in each batch (sampler will + choose the number of sentences to satisfy this contraint). + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + """ + params = AttributeDict( + { + "exp_dir": Path("conformer_lm/exp_1"), + "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), + "num_tokens": 5000, + "blank_sym": 0, + "bos_sym": 1, + "eos_sym": 1, + "start_epoch": 0, + "num_epochs": 20, + "num_valid_batches": 100, + "symbols_per_batch": 10000, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "reset_interval": 200, + "valid_interval": 3000, + "beam_size": 10, + "accum_grad": 1, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + "lr_factor": 2.0, + "warm_step": 20000, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + batch: Tuple, + is_training: bool, +): + + """ + Compute training or validation loss given the model and its inputs + (this corresponds to log-prob of the targets, with weighting + of 1.0 for masked subsequences + (including padding blanks), and something smaller, e.g. 0.25, + for non-masked positions (this is not totally trivial due to + a small amount of randomization of symbols). + + This loss is not normalized; you can divide by batch[4].sum() + to get a normalized loss (i.e. divide by soft-count). + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of MaskedLmConformer in our case. + batch: + A batch of data, actually a tuple of 5 tensors (on the device), as returned + by collate_fn in ./dataset.py. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + + Returns: + Returns the loss as a scalar tensor. + """ + (masked_src_symbols, src_symbols, + tgt_symbols, src_key_padding_mask, tgt_weights) = batch + + with torch.set_grad_enabled(is_training): + memory, pos_emb = model(masked_src_symbols, src_key_padding_mask) + tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols, + tgt_symbols, src_key_padding_mask) + loss = (tgt_nll * tgt_weights).sum() + + assert loss.requires_grad == is_training + + return loss + + +def compute_validation_loss( + device: torch.device, + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> None: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = 0.0 + tot_frames = 0.0 + for batch_idx, batch in enumerate(valid_dl): + batch = tuple(x.to(device) for x in batch) + + # `batch` is actually a tuple.. we'll unpack it later. + loss = compute_loss(model, batch, is_training=False) + num_frames = batch[4].sum() + + assert loss.requires_grad is False + assert ctc_loss.requires_grad is False + assert att_loss.requires_grad is False + + loss_cpu = loss.detach().cpu().item() + num_frames_cpu = num_frames.cpu().item() + + tot_loss += loss_cpu + tot_frames += num_frames_cpu + + + if world_size > 1: + s = torch.tensor( + [tot_loss, tot_frames], + device=loss.device, + ) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + (tot_loss, tot_frames) = s.cpu().tolist() + + params.valid_loss = tot_loss / tot_frames + + if params.valid_loss < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = params.valid_loss + + +def train_one_epoch( + device: torch.device, + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + device: + The device to use for training (model must be on this device) + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() # training mode + + tot_loss = 0.0 # sum of losses over all batches + tot_frames = 0.0 # sum of frames over all batches + + params.tot_loss = 0.0 + params.tot_frames = 0.0 + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch = tuple(x.to(device) for x in batch) + + loss = compute_loss( + model=model, + batch=batch, + is_training=True, + ) + + optimizer.zero_grad() + loss.backward() # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total + # gradient scale so this should not matter. + # clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + num_frames_cpu = batch[4].sum().cpu().item() + + tot_loss += loss_cpu + tot_frames += num_frames_cpu + + params.tot_frames += num_frames_cpu + params.tot_loss += loss_cpu + + tot_avg_loss = tot_loss / tot_frames + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_frames = 0.0 # sum of frames over all batches + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + device=device, + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " + f"best valid epoch: {params.best_valid_epoch}" + ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) + + params.train_loss = params.tot_loss / params.tot_frames + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + num_tokens = params.num_tokens + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + + logging.info("About to create model") + model = MaskedLmConformer( + num_classes=params.num_tokens, + d_model=params.attention_dim, + nhead=params.nhead, + num_decoder_layers=params.num_decoder_layers, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Moam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + train,test = dataset.load_train_test_lm_dataset(params.lm_dataset) + + collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym, + eos_sym=params.eos_sym, + blank_sym=params.blank_sym, + mask_proportion=0.15, + padding_proportion=0.15, + randomize_proportion=0.05, + inv_mask_length=0.25, + unmasked_weight=0.25)) + + train_sampler = dataset.LmBatchSampler(train, + symbols_per_batch=params.symbols_per_batch, + world_size=world_size, rank=rank) + test_sampler = dataset.LmBatchSampler(test, + symbols_per_batch=params.symbols_per_batch, + world_size=world_size, rank=rank) + + train_dl = torch.utils.data.DataLoader(train, + batch_sampler=train_sampler, + collate_fn=collate_fn) + valid_dl = torch.utils.data.DataLoader(test, + batch_sampler=test_sampler, + collate_fn=collate_fn) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + device=device, + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main()